From ab5ca285a4583842a2080f2c65937b5b8403036e Mon Sep 17 00:00:00 2001
From: Yangge Li <li213@illinois.edu>
Date: Fri, 24 Jun 2022 10:13:01 -0500
Subject: [PATCH] working on integrating improved parser with verification

---
 demo/demo3.py                                 | 22 +++---
 .../scene_verifier/analysis/verifier.py       |  3 +
 .../scene_verifier/automaton/guard.py         | 19 ++---
 .../scene_verifier/scenario/scenario.py       | 71 +++++++++++++------
 requirements.txt                              |  2 +-
 setup.py                                      |  2 +-
 6 files changed, 76 insertions(+), 43 deletions(-)

diff --git a/demo/demo3.py b/demo/demo3.py
index a459d2ec..c3372863 100644
--- a/demo/demo3.py
+++ b/demo/demo3.py
@@ -42,35 +42,35 @@ class State:
 
 
 if __name__ == "__main__":
-    input_code_name = './example_controller4.py'
+    input_code_name = './demo/example_controller4.py'
     scenario = Scenario()
 
     car = CarAgent('car1', file_name=input_code_name)
     scenario.add_agent(car)
     car = NPCAgent('car2')
     scenario.add_agent(car)
-    car = NPCAgent('car3')
-    scenario.add_agent(car)
-    car = NPCAgent('car4')
-    scenario.add_agent(car)
+    # car = NPCAgent('car3')
+    # scenario.add_agent(car)
+    # car = NPCAgent('car4')
+    # scenario.add_agent(car)
     tmp_map = SimpleMap3()
     scenario.set_map(tmp_map)
     scenario.set_init(
         [
             [[0, -0.2, 0, 1.0],[0.01, 0.2, 0, 1.0]],
             [[10, 0, 0, 0.5],[10, 0, 0, 0.5]], 
-            [[20, 3, 0, 0.5],[20, 3, 0, 0.5]], 
-            [[30, 0, 0, 0.5],[30, 0, 0, 0.5]], 
+            # [[20, 3, 0, 0.5],[20, 3, 0, 0.5]], 
+            # [[30, 0, 0, 0.5],[30, 0, 0, 0.5]], 
         ],
         [
             (VehicleMode.Normal, LaneMode.Lane1, LaneObjectMode.Vehicle),
             (VehicleMode.Normal, LaneMode.Lane1, LaneObjectMode.Vehicle),
-            (VehicleMode.Normal, LaneMode.Lane0, LaneObjectMode.Vehicle),
-            (VehicleMode.Normal, LaneMode.Lane1, LaneObjectMode.Vehicle),
+            # (VehicleMode.Normal, LaneMode.Lane0, LaneObjectMode.Vehicle),
+            # (VehicleMode.Normal, LaneMode.Lane1, LaneObjectMode.Vehicle),
         ]
     )
-    traces = scenario.simulate(70, 0.05)
-    # traces = scenario.verify(70, 0.05)
+    # traces = scenario.simulate(70, 0.05)
+    traces = scenario.verify(70, 0.05)
 
     # fig = plt.figure(2)
     # fig = plot_map(tmp_map, 'g', fig)
diff --git a/dryvr_plus_plus/scene_verifier/analysis/verifier.py b/dryvr_plus_plus/scene_verifier/analysis/verifier.py
index 56a131b7..3db38a57 100644
--- a/dryvr_plus_plus/scene_verifier/analysis/verifier.py
+++ b/dryvr_plus_plus/scene_verifier/analysis/verifier.py
@@ -18,6 +18,7 @@ class Verifier:
         self,
         init_list: List[float],
         init_mode_list: List[str],
+        static_list: List[str],
         agent_list:List[BaseAgent], 
         transition_graph, 
         time_horizon, 
@@ -29,6 +30,8 @@ class Verifier:
             root.init[agent.id] = init_list[i]
             init_mode = [elem.name for elem in init_mode_list[i]]
             root.mode[agent.id] = init_mode 
+            init_static = [elem.name for elem in static_list[i]]
+            root.static[agent.id] = init_static
             root.agent[agent.id] = agent 
             root.type = 'reachtube'
         self.reachtube_tree_root = root 
diff --git a/dryvr_plus_plus/scene_verifier/automaton/guard.py b/dryvr_plus_plus/scene_verifier/automaton/guard.py
index 4fda66aa..7b822806 100644
--- a/dryvr_plus_plus/scene_verifier/automaton/guard.py
+++ b/dryvr_plus_plus/scene_verifier/automaton/guard.py
@@ -13,6 +13,7 @@ from dryvr_plus_plus.scene_verifier.map.lane_map import LaneMap
 from dryvr_plus_plus.scene_verifier.map.lane_segment import AbstractLane
 from dryvr_plus_plus.scene_verifier.utils.utils import *
 from dryvr_plus_plus.scene_verifier.code_parser.parser import Reduction, ReductionType
+from dryvr_plus_plus.scene_verifier.agents.base_agent import BaseAgent
 
 class LogicTreeNode:
     def __init__(self, data, child = [], val = None, mode_guard = None):
@@ -108,7 +109,7 @@ class GuardExpressionAst:
         # Therefore we are not be able to get free symbols from it
         # Thus we need to replace "==" to something else
 
-        symbols_map = {v: k for k, v in self.cont_variables.items()}
+        symbols_map = {v: k for k, v in self.cont_variables.items() if k in guard_str}
 
         for vars in reversed(self.cont_variables):
             guard_str = guard_str.replace(vars, self.cont_variables[vars])
@@ -502,7 +503,7 @@ class GuardExpressionAst:
             res = res and tmp 
         return res
             
-    def _evaluate_guard_disc(self, root, agent, disc_var_dict, cont_var_dict, lane_map):
+    def _evaluate_guard_disc(self, root, agent:BaseAgent, disc_var_dict, cont_var_dict, lane_map):
         """
         Recursively called function to evaluate guard with only discrete variables
         The function will evaluate all guards with discrete variables and replace the nodes with discrete guards by
@@ -578,9 +579,9 @@ class GuardExpressionAst:
                         root = ast.parse('False').body[0].value    
                 else:
                     # TODO-PARSER: Handle This
-                    for mode_name in agent.controller.modes:
+                    for mode_name in agent.controller.mode_defs:
                         # TODO-PARSER: Handle This
-                        if res in agent.controller.modes[mode_name]:
+                        if res in agent.controller.mode_defs[mode_name].modes:
                             res = mode_name+'.'+res
                             break
                     root = ast.parse(str(res)).body[0].value
@@ -593,14 +594,14 @@ class GuardExpressionAst:
             if expr in disc_var_dict:
                 val = disc_var_dict[expr]
                 # TODO-PARSER: Handle This
-                for mode_name in agent.controller.modes:
+                for mode_name in agent.controller.mode_defs:
                     # TODO-PARSER: Handle This
-                    if val in agent.controller.modes[mode_name]:
+                    if val in agent.controller.mode_defs[mode_name].modes:
                         val = mode_name+'.'+val
                         break
                 return val, root
             # TODO-PARSER: Handle This
-            elif root.value.id in agent.controller.modes:
+            elif root.value.id in agent.controller.mode_defs:
                 return expr, root
             else:
                 return True, root
@@ -622,9 +623,9 @@ class GuardExpressionAst:
             if expr in disc_var_dict:
                 val = disc_var_dict[expr]
                 # TODO-PARSER: Handle This
-                for mode_name in agent.controller.modes:
+                for mode_name in agent.controller.mode_defs:
                     # TODO-PARSER: Handle This
-                    if val in agent.controller.modes[mode_name]:
+                    if val in agent.controller.mode_defs[mode_name].modes:
                         val = mode_name + '.' + val 
                         break 
                 return val, root
diff --git a/dryvr_plus_plus/scene_verifier/scenario/scenario.py b/dryvr_plus_plus/scene_verifier/scenario/scenario.py
index 19b8820c..5caa42bf 100644
--- a/dryvr_plus_plus/scene_verifier/scenario/scenario.py
+++ b/dryvr_plus_plus/scene_verifier/scenario/scenario.py
@@ -2,9 +2,9 @@ from typing import Tuple, List, Dict, Any
 import copy
 import itertools
 import warnings
+from collections import defaultdict
 
 import numpy as np
-from sympy import Q
 
 from dryvr_plus_plus.scene_verifier.agents.base_agent import BaseAgent
 from dryvr_plus_plus.scene_verifier.automaton.guard import GuardExpressionAst
@@ -90,6 +90,7 @@ class Scenario:
     def verify(self, time_horizon, time_step):
         init_list = []
         init_mode_list = []
+        static_list = []
         agent_list = []
         for agent_id in self.agent_dict:
             init = self.init_dict[agent_id]
@@ -98,14 +99,36 @@ class Scenario:
                 init = [init, init]
             init_list.append(init)
             init_mode_list.append(self.init_mode_dict[agent_id])
+            static_list.append(self.static_dict[agent_id])
             agent_list.append(self.agent_dict[agent_id])
-        return self.verifier.compute_full_reachtube(init_list, init_mode_list, agent_list, self, time_horizon, time_step, self.map)
+        return self.verifier.compute_full_reachtube(init_list, init_mode_list, static_list, agent_list, self, time_horizon, time_step, self.map)
+
+    def apply_reset(self, agent: BaseAgent, reset_list, all_agent_state) -> Tuple[str, np.ndarray]:
+        # reset_expr = ResetExpression(reset_list)
+        # continuous_variable_dict, discrete_variable_dict, _ = self.sensor.sense(self, agent, all_agent_state, self.map)
+        # dest = reset_expr.get_dest(agent, all_agent_state[agent.id], discrete_variable_dict, self.map)
+        # rect = reset_expr.apply_reset_continuous(agent, continuous_variable_dict, self.map)
+        # return dest, rect
+        dest = []
+        rect = []
+        
+        agent_state, agent_mode, agent_static = all_agent_state[agent.id]
 
-    def apply_reset(self, agent, reset_list, all_agent_state) -> Tuple[str, np.ndarray]:
-        reset_expr = ResetExpression(reset_list)
-        continuous_variable_dict, discrete_variable_dict, _ = self.sensor.sense(self, agent, all_agent_state, self.map)
-        dest = reset_expr.get_dest(agent, all_agent_state[agent.id], discrete_variable_dict, self.map)
-        rect = reset_expr.apply_reset_continuous(agent, continuous_variable_dict, self.map)
+        # First get the transition destinations
+        dest = copy.deepcopy(agent_mode)
+        possible_dest = [[elem] for elem in dest]
+        ego_type = agent.controller.ego_type
+        for reset in reset_list:
+            reset_variable = reset.var
+            expr = reset.expr
+            if "mode" in reset_variable:
+                for var_loc, discrete_variable_ego in enumerate(agent.controller.state_defs[ego_type].disc):
+                    if discrete_variable_ego == reset_variable:
+                        break
+                if 'map' in expr:
+                    for var in discrete_variable_dict
+
+        # Then get the transition updated rect
         return dest, rect
 
     def apply_cont_var_updater(self,cont_var_dict, updater):
@@ -131,11 +154,11 @@ class Scenario:
             agent_mode = node.mode[agent_id]
             # TODO-PARSER: update how we get all next modes
             # The getNextModes function will return 
-            paths = agent.controller.getNextModes()
             state_dict = {}
             for tmp in node.agent:
                 state_dict[tmp] = (node.trace[tmp][0], node.mode[tmp], node.static[tmp])
             cont_var_dict_template, discrete_variable_dict, len_dict = self.sensor.sense(self, agent, state_dict, self.map)
+            paths = agent.controller.getNextModes()
             for path in paths:
                 guard_list = path[0]
                 reset = path[1]
@@ -251,16 +274,11 @@ class Scenario:
             
             cont_var_dict_template, discrete_variable_dict, length_dict = self.sensor.sense(self, agent, state_dict, self.map)
             # TODO-PARSER: Get equivalent for this function
-            paths = agent.controller.getNextModes(agent_mode)
+            paths = agent.controller.getNextModes()
             for path in paths:
                 # Construct the guard expression
-                guard_list = []
-                reset_list = []
-                for item in path:
-                    if isinstance(item, Guard):
-                        guard_list.append(item)
-                    elif isinstance(item, Reset):
-                        reset_list.append(item)
+                guard_list = path[0]
+                reset = path[1]
                 guard_expression = GuardExpressionAst(guard_list)
                 
                 cont_var_updater = guard_expression.parse_any_all_new(cont_var_dict_template, discrete_variable_dict, length_dict)
@@ -269,9 +287,9 @@ class Scenario:
                 if not guard_can_satisfied:
                     continue
                 if agent_id not in agent_guard_dict:
-                    agent_guard_dict[agent_id] = [(guard_expression, cont_var_updater, copy.deepcopy(discrete_variable_dict), reset_list)]
+                    agent_guard_dict[agent_id] = [(guard_expression, cont_var_updater, copy.deepcopy(discrete_variable_dict), reset)]
                 else:
-                    agent_guard_dict[agent_id].append((guard_expression, cont_var_updater, copy.deepcopy(discrete_variable_dict), reset_list))
+                    agent_guard_dict[agent_id].append((guard_expression, cont_var_updater, copy.deepcopy(discrete_variable_dict), reset))
 
         trace_length = int(len(list(node.trace.values())[0])/2)
         guard_hits = []
@@ -285,10 +303,11 @@ class Scenario:
             
             for agent_id in agent_guard_dict:
                 agent:BaseAgent = self.agent_dict[agent_id]
-                agent_state, agent_mode = state_dict[agent_id]
+                agent_state, agent_mode, agent_static = state_dict[agent_id]
                 agent_state = agent_state[1:]
                 continuous_variable_dict, _, _ = self.sensor.sense(self, agent, state_dict, self.map)
-                for guard_expression, continuous_variable_updater, discrete_variable_dict, reset_list in agent_guard_dict[agent_id]:
+                resets = defaultdict(list)
+                for guard_expression, continuous_variable_updater, discrete_variable_dict, reset in agent_guard_dict[agent_id]:
                     new_cont_var_dict = copy.deepcopy(continuous_variable_dict)
                     one_step_guard:GuardExpressionAst = copy.deepcopy(guard_expression)
 
@@ -298,8 +317,17 @@ class Scenario:
                         continue
                     guard_satisfied, is_contained = one_step_guard.evaluate_guard_cont(agent, new_cont_var_dict, self.map)
                     any_contained = any_contained or is_contained
+                    # TODO: Can we also store the cont and disc var dict so we don't have to call sensor again?
                     if guard_satisfied:
-                        hits.append((agent_id, guard_list, reset_list))
+                        reset_expr = ResetExpression(reset)
+                        resets[reset_expr.var].append(reset_expr)
+                # Perform combination over all possible resets to generate all possible real resets
+                combined_reset_list = list(itertools.product(*resets.values()))
+                if len(combined_reset_list)==1 and combined_reset_list[0]==():
+                    continue
+                for i in range(len(combined_reset_list)):
+                    # a list of reset expression
+                    hits.append((agent_id, guard_expression, combined_reset_list[i]))
             if hits != []:
                 guard_hits.append((hits, state_dict, idx))
                 guard_hit_bool = True 
@@ -312,6 +340,7 @@ class Scenario:
         reset_idx_dict = {}
         for hits, all_agent_state, hit_idx in guard_hits:
             for agent_id, guard_list, reset_list in hits:
+                # TODO: Need to change this function to handle the new reset expression and then I am done 
                 dest_list,reset_rect = self.apply_reset(node.agent[agent_id], reset_list, all_agent_state)
                 if agent_id not in reset_dict:
                     reset_dict[agent_id] = {}
diff --git a/requirements.txt b/requirements.txt
index b273c15c..49471cda 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
 numpy~=1.22.1
-scipy~=1.8.0
+scipy~=1.6.1
 matplotlib~=3.4.2
 polytope~=0.2.3
 pyvista~=0.32.1
diff --git a/setup.py b/setup.py
index 536556a8..4ba70ef5 100644
--- a/setup.py
+++ b/setup.py
@@ -13,7 +13,7 @@ setup(
     python_requires='>=3.8',
     install_requires=[
         "numpy~=1.22.1",
-        "scipy~=1.8.0",
+        "scipy~=1.6.1",
         "matplotlib~=3.4.2",
         "polytope~=0.2.3",
         "pyvista~=0.32.1",
-- 
GitLab