From 333f3a96942f09a0f3ef9f5dec9a882177733b43 Mon Sep 17 00:00:00 2001
From: unknown <lyg1597@gmail.com>
Date: Mon, 2 May 2022 00:08:38 -0500
Subject: [PATCH] working on verification, need to fix ast.Unary object

---
 example_two_car_lane_switch.py | 70 ++++++++++++++++++++--------------
 ourtool/analysis/verifier.py   |  9 +++--
 ourtool/automaton/guard.py     | 10 ++++-
 ourtool/automaton/reset.py     | 10 ++++-
 ourtool/scenario/scenario.py   | 21 ++++++----
 plotter/__init__.py            | 22 +++++++++++
 plotter/parser.py              | 38 ++++++++++++++++++
 plotter/plotter2D.py           | 59 ++++++++++++++++++++++++++++
 plotter/plotter3D.py           | 68 +++++++++++++++++++++++++++++++++
 9 files changed, 264 insertions(+), 43 deletions(-)
 create mode 100644 plotter/__init__.py
 create mode 100644 plotter/parser.py
 create mode 100644 plotter/plotter2D.py
 create mode 100644 plotter/plotter3D.py

diff --git a/example_two_car_lane_switch.py b/example_two_car_lane_switch.py
index 74e9a88c..145a1ffc 100644
--- a/example_two_car_lane_switch.py
+++ b/example_two_car_lane_switch.py
@@ -1,8 +1,3 @@
-import matplotlib.pyplot as plt
-from ourtool.agents.car_agent import CarAgent
-import numpy as np
-from user.simple_map import SimpleMap, SimpleMap2
-from ourtool.scenario.scenario import Scenario
 from enum import Enum, auto
 
 from ourtool.map.lane_map import LaneMap
@@ -43,17 +38,25 @@ def controller(ego: State, other: State, lane_map):
             if lane_map.has_right(ego.lane_mode):
                 output.vehicle_mode = VehicleMode.SwitchRight
     if ego.vehicle_mode == VehicleMode.SwitchLeft:
-        if ego.y >= lane_map.lane_geometry(ego.lane_mode)-2.5:
+        if  lane_map.lane_geometry(ego.lane_mode) - ego.y <= -2.5:
             output.vehicle_mode = VehicleMode.Normal
             output.lane_mode = lane_map.left_lane(ego.lane_mode)
     if ego.vehicle_mode == VehicleMode.SwitchRight:
-        if ego.y <= lane_map.lane_geometry(ego.lane_mode)+2.5:
+        if lane_map.lane_geometry(ego.lane_mode)-ego.y >= 2.5:
             output.vehicle_mode = VehicleMode.Normal
             output.lane_mode = lane_map.right_lane(ego.lane_mode)
 
     return output
 
 
+from ourtool.agents.car_agent import CarAgent
+from ourtool.scenario.scenario import Scenario
+from user.simple_map import SimpleMap, SimpleMap2
+from plotter.plotter2D import plot_tree
+
+import matplotlib.pyplot as plt
+import numpy as np
+
 if __name__ == "__main__":
     input_code_name = 'example_two_car_lane_switch.py'
     scenario = Scenario()
@@ -64,33 +67,42 @@ if __name__ == "__main__":
     scenario.add_agent(car)
     scenario.add_map(SimpleMap2())
     scenario.set_init(
-        [[10, 3, 0, 0.5], [0, 3, 0, 1.0]],
         [
-            (VehicleMode.Normal, LaneMode.Lane0),
-            (VehicleMode.Normal, LaneMode.Lane0)
+            [[10, 0, 0, 0.5],[10, 0, 0, 0.5]], 
+            [[0, -0.2, 0, 1.0],[0, 0.2, 0, 1.0]],
+        ],
+        [
+            (VehicleMode.Normal, LaneMode.Lane1),
+            (VehicleMode.Normal, LaneMode.Lane1)
         ]
     )
     # simulator = Simulator()
     # traces = scenario.simulate(40)
     traces = scenario.verify(40)
 
-    plt.plot([0, 40], [3, 3], 'g')
-    plt.plot([0, 40], [0, 0], 'g')
-    plt.plot([0, 40], [-3, -3], 'g')
-
-    queue = [traces]
-    while queue != []:
-        node = queue.pop(0)
-        traces = node.trace
-        # for agent_id in traces:
-        agent_id = 'car2'
-        trace = np.array(traces[agent_id])
-        plt.plot(trace[:, 1], trace[:, 2], 'r')
-
-        agent_id = 'car1'
-        trace = np.array(traces[agent_id])
-        plt.plot(trace[:, 1], trace[:, 2], 'b')
-
-        # if node.child != []:
-        queue += node.child
+    fig = plt.figure()
+    fig = plot_tree(traces, 'car1', 1, [2], 'b', fig)
+    fig = plot_tree(traces, 'car2', 1, [2], 'r', fig)
+
     plt.show()
+
+    # plt.plot([0, 40], [3, 3], 'g')
+    # plt.plot([0, 40], [0, 0], 'g')
+    # plt.plot([0, 40], [-3, -3], 'g')
+
+    # queue = [traces]
+    # while queue != []:
+    #     node = queue.pop(0)
+    #     traces = node.trace
+    #     # for agent_id in traces:
+    #     agent_id = 'car2'
+    #     trace = np.array(traces[agent_id])
+    #     plt.plot(trace[:, 1], trace[:, 2], 'r')
+
+    #     agent_id = 'car1'
+    #     trace = np.array(traces[agent_id])
+    #     plt.plot(trace[:, 1], trace[:, 2], 'b')
+
+    #     # if node.child != []:
+    #     queue += node.child
+    # plt.show()
diff --git a/ourtool/analysis/verifier.py b/ourtool/analysis/verifier.py
index 31cfd7f4..918cc5e2 100644
--- a/ourtool/analysis/verifier.py
+++ b/ourtool/analysis/verifier.py
@@ -40,6 +40,7 @@ class Verifier:
             if remain_time <= 0:
                 continue 
             # For reachtubes not already computed
+            # TODO: can add parallalization for this loop
             for agent_id in node.agent:
                 if agent_id not in node.trace:
                     # Compute the trace starting from initial condition
@@ -101,6 +102,8 @@ class Verifier:
                 verification_queue.append(tmp)
 
             """Truncate trace of current node based on max_end_idx"""
-            for agent_idx in node.agent:
-                node.trace[agent_idx] = node.trace[agent_idx][:(max_end_idx+1)*2]
-        
\ No newline at end of file
+            """Only truncate when there's transitions"""
+            if all_possible_transitions:
+                for agent_idx in node.agent:
+                    node.trace[agent_idx] = node.trace[agent_idx][:(max_end_idx+1)*2]
+        return root
\ No newline at end of file
diff --git a/ourtool/automaton/guard.py b/ourtool/automaton/guard.py
index 5ff33274..ba209420 100644
--- a/ourtool/automaton/guard.py
+++ b/ourtool/automaton/guard.py
@@ -5,6 +5,7 @@ import pickle
 # from ourtool.automaton.hybrid_io_automaton import HybridIoAutomaton
 # from pythonparser import Guard
 import ast
+import copy
 
 from z3 import *
 import sympy
@@ -22,7 +23,7 @@ class GuardExpressionAst:
     def __init__(self, guard_list):
         self.ast_list = []
         for guard in guard_list:
-            self.ast_list.append(guard.ast)
+            self.ast_list.append(copy.deepcopy(guard.ast))
         self.cont_variables = {}
         self.varDict = {'t':Real('t')}
 
@@ -241,6 +242,8 @@ class GuardExpressionAst:
             if any([var in expr for var in disc_var_dict]):
                 left, root.left = self._evaluate_guard_disc(root.left, agent, disc_var_dict, lane_map)
                 right, root.comparators[0] = self._evaluate_guard_disc(root.comparators[0], agent, disc_var_dict, lane_map)
+                if isinstance(left, bool) or isinstance(right, bool):
+                    return True, root
                 if isinstance(root.ops[0], ast.GtE):
                     res = left>=right
                 elif isinstance(root.ops[0], ast.Gt):
@@ -280,6 +283,9 @@ class GuardExpressionAst:
                         break
                 return res, root     
         elif isinstance(root, ast.BinOp):
+            # Check left and right in the binop and replace all attributes involving discrete variables
+            left, root.left = self._evaluate_guard_disc(root.left, agent, disc_var_dict, lane_map)
+            right, root.right = self._evaluate_guard_disc(root.right, agent, disc_var_dict, lane_map)
             return True, root
         elif isinstance(root, ast.Call):
             expr = astunparse.unparse(root)
@@ -300,6 +306,8 @@ class GuardExpressionAst:
                         root = ast.parse('True').body[0].value
                     else:
                         root = ast.parse('False').body[0].value    
+                else:
+                    root = ast.parse(str(res)).body[0].value
                 return res, root
             else:
                 return True, root
diff --git a/ourtool/automaton/reset.py b/ourtool/automaton/reset.py
index 440f993b..a0e9ddcb 100644
--- a/ourtool/automaton/reset.py
+++ b/ourtool/automaton/reset.py
@@ -1,3 +1,5 @@
+import itertools
+
 import numpy as np 
 
 class ResetExpression:
@@ -74,5 +76,9 @@ class ResetExpression:
                     tmp = tmp[1].split('.')
                     if tmp[0].strip(' ') in agent.controller.modes:
                         possible_dest[i] = [tmp[1]]                            
-            
-        return possible_dest
\ No newline at end of file
+        all_dest = itertools.product(*possible_dest)
+        res = []
+        for dest in all_dest:
+            dest = ','.join(dest)
+            res.append(dest)
+        return res
\ No newline at end of file
diff --git a/ourtool/scenario/scenario.py b/ourtool/scenario/scenario.py
index d6284c13..39060ec4 100644
--- a/ourtool/scenario/scenario.py
+++ b/ourtool/scenario/scenario.py
@@ -146,7 +146,7 @@ class Scenario:
     def check_guard_hit(self, state_dict):
         lane_map = self.map 
         guard_hits = []
-        is_conatined = False        # TODO: Handle this
+        is_contained = False        # TODO: Handle this
         for agent_id in state_dict:
             agent:BaseAgent = self.agent_dict[agent_id]
             agent_state, agent_mode = state_dict[agent_id]
@@ -185,6 +185,8 @@ class Scenario:
         trace_length = int(len(list(node.trace.values())[0])/2)
         guard_hits = []
         guard_hit_bool = False
+
+        # TODO: can add parallalization for this loop
         for idx in range(0,trace_length):
             # For each trace, check with the guard to see if there's any possible transition
             # Store all possible transition in a list
@@ -206,26 +208,29 @@ class Scenario:
         reset_idx_dict = {}
         for hits, all_agent_state, hit_idx in guard_hits:
             for agent_id, guard_list, reset_list in hits:
-                dest,reset_rect = self.apply_reset(node.agent[agent_id], reset_list, all_agent_state)
+                dest_list,reset_rect = self.apply_reset(node.agent[agent_id], reset_list, all_agent_state)
                 if agent_id not in reset_dict:
                     reset_dict[agent_id] = {}
                     reset_idx_dict[agent_id] = {}
-                if dest not in reset_dict[agent_id]:
-                    reset_dict[agent_id][dest] = []
-                    reset_idx_dict[agent_id][dest] = []
-                reset_dict[agent_id][dest].append(reset_rect)
-                reset_idx_dict[agent_id][dest].append(hit_idx)
-        
+                for dest in dest_list:
+                    if dest not in reset_dict[agent_id]:
+                        reset_dict[agent_id][dest] = []
+                        reset_idx_dict[agent_id][dest] = []
+                    reset_dict[agent_id][dest].append(reset_rect)
+                    reset_idx_dict[agent_id][dest].append(hit_idx)
+            
         # Combine reset rects and construct transitions
         for agent in reset_dict:
             for dest in reset_dict[agent]:
                 combined_rect = None 
                 for rect in reset_dict[agent][dest]:
+                    rect = np.array(rect)
                     if combined_rect is None:
                         combined_rect = rect 
                     else:
                         combined_rect[0,:] = np.minimum(combined_rect[0,:], rect[0,:])
                         combined_rect[1,:] = np.maximum(combined_rect[1,:], rect[1,:])
+                combined_rect = combined_rect.tolist()
                 min_idx = min(reset_idx_dict[agent][dest])
                 max_idx = max(reset_idx_dict[agent][dest])
                 transition = (agent, node.mode[agent], dest, combined_rect, (min_idx, max_idx))
diff --git a/plotter/__init__.py b/plotter/__init__.py
new file mode 100644
index 00000000..5dbc287d
--- /dev/null
+++ b/plotter/__init__.py
@@ -0,0 +1,22 @@
+#                       _oo0oo_
+#                      o8888888o
+#                      88" . "88
+#                      (| -_- |)
+#                      0\  =  /0
+#                    ___/`---'\___
+#                  .' \\|     |// '.
+#                 / \\|||  :  |||// \
+#                / _||||| -:- |||||- \
+#               |   | \\\  -  /// |   |
+#               | \_|  ''\---/''  |_/ |
+#               \  .-\__  '-'  ___/-. /
+#             ___'. .'  /--.--\  `. .'___
+#          ."" '<  `.___\_<|>_/___.' >' "".
+#         | | :  `- \`.;`\ _ /`;.`/ - ` : | |
+#         \  \ `_.   \_ __\ /__ _/   .-` /  /
+#     =====`-.____`.___ \_____/___.-`___.-'=====
+#                       `=---='
+#
+#
+#     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#	Codes are far away from bugs with the protection
\ No newline at end of file
diff --git a/plotter/parser.py b/plotter/parser.py
new file mode 100644
index 00000000..e5567fbd
--- /dev/null
+++ b/plotter/parser.py
@@ -0,0 +1,38 @@
+"""
+This file consist parser code for DryVR reachtube output
+"""
+
+from typing import TextIO
+import re 
+
+class Parser:
+    def __init__(self, f: TextIO):
+        data = f.readlines()
+        curr_key = ""
+        self.data_dict = {}
+        i = 0
+        while i < len(data):
+            line = data[i]
+            if not re.match('^[-+0-9+.+0-9+e+0-9 ]+$', line):
+                self.data_dict[line] = []
+                curr_key = line
+                i += 1
+            else:
+                line_lower = data[i]
+                line_lower_list = line_lower.split(' ')
+                line_lower_list = [float(elem) for elem in line_lower_list]
+                line_upper = data[i+1]
+                line_upper_list = line_upper.split(' ')
+                line_upper_list = [float(elem) for elem in line_upper_list]
+                rect = [line_lower_list, line_upper_list]
+                self.data_dict[curr_key].append(rect)
+                i += 2
+                
+    
+    def get_all_data(self):
+        res = []
+        for key in self.data_dict:
+            res += self.data_dict[key]
+        return res
+
+    
\ No newline at end of file
diff --git a/plotter/plotter2D.py b/plotter/plotter2D.py
new file mode 100644
index 00000000..2c79b5fc
--- /dev/null
+++ b/plotter/plotter2D.py
@@ -0,0 +1,59 @@
+"""
+This file consist main plotter code for DryVR reachtube output
+"""
+
+import matplotlib.patches as patches
+import matplotlib.pyplot as plt
+import numpy as np 
+from typing import List 
+
+colors = ['red', 'green', 'blue', 'yellow', 'black']
+
+def plot(
+    data, 
+    x_dim: int = 0, 
+    y_dim_list: List[int] = [1], 
+    color = 'b', 
+    fig = None, 
+    x_lim = (float('inf'), -float('inf')), 
+    y_lim = (float('inf'), -float('inf'))
+):
+    if fig is None:
+        fig = plt.figure()
+    ax = plt.gca()
+    x_min, x_max = x_lim
+    y_min, y_max = y_lim
+    for rect in data:
+        lb = rect[0]
+        ub = rect[1]
+        for y_dim in y_dim_list:
+            rect_patch = patches.Rectangle((lb[x_dim], lb[y_dim]), ub[x_dim]-lb[x_dim], ub[y_dim]-lb[y_dim], color = color)
+            ax.add_patch(rect_patch)
+            x_min = min(lb[x_dim], x_min)
+            y_min = min(lb[y_dim], y_min)
+            x_max = max(ub[x_dim], x_max)
+            y_max = max(ub[y_dim], y_max)
+
+    ax.set_xlim([x_min-1, x_max+1])
+    ax.set_ylim([y_min-1, y_max+1])
+    return fig, (x_min, x_max), (y_min, y_max)
+
+def plot_tree(root, agent_id, x_dim: int=0, y_dim_list: List[int]=[1], color='b', fig = None):
+    if fig is None:
+        fig = plt.figure()
+    
+    queue = [root]
+    x_lim = (float('inf'), -float('inf')) 
+    y_lim = (float('inf'), -float('inf'))
+    while queue != []:
+        node = queue.pop(0)
+        traces = node.trace
+        trace = traces[agent_id]
+        data = []
+        for i in range(0,len(trace),2):
+            data.append([trace[i], trace[i+1]])
+        fig, x_lim, y_lim = plot(data, x_dim, y_dim_list, color, fig, x_lim, y_lim)
+
+        queue += node.child
+
+    return fig
\ No newline at end of file
diff --git a/plotter/plotter3D.py b/plotter/plotter3D.py
new file mode 100644
index 00000000..1ede38e0
--- /dev/null
+++ b/plotter/plotter3D.py
@@ -0,0 +1,68 @@
+# Plot polytope in 3d
+# Written by: Kristina Miller
+
+import numpy as np
+import matplotlib.pyplot as plt
+import mpl_toolkits.mplot3d as a3
+
+from scipy.spatial import ConvexHull
+import polytope as pc
+import pyvista as pv
+
+def plot3d(node, x_dim, y_dim, z_dim, ax = None):
+    if ax is None:
+        ax = pv.Plotter()
+    lower_bound = []
+    upper_bound = []
+    for key in sorted(node.lower_bound):
+        lower_bound.append(node.lower_bound[key])
+    for key in sorted(node.upper_bound):
+        upper_bound.append(node.upper_bound[key])
+
+    for i in range(min(len(lower_bound), len(upper_bound))):
+        lb = list(map(float, lower_bound[i]))
+        ub = list(map(float, upper_bound[i]))
+
+        box = [[lb[x_dim], lb[y_dim], lb[z_dim]],[ub[x_dim], ub[y_dim], ub[z_dim]]]
+        poly = pc.box2poly(np.array(box).T)
+        plot_polytope_3d(poly.A, poly.b, ax = ax, color = '#b3de69')
+
+def plot_polytope_3d(A, b, ax = None, color = 'red', trans = 0.2, edge = True):
+	if ax is None:
+		ax = pv.Plotter() 
+
+	poly = pc.Polytope(A = A, b = b)
+	vertices = pc.extreme(poly)
+	cloud = pv.PolyData(vertices)
+	volume = cloud.delaunay_3d()
+	shell = volume.extract_geometry()
+	ax.add_mesh(shell, opacity=trans, color=color)
+	if edge:
+		edges = shell.extract_feature_edges(20)
+		ax.add_mesh(edges, color="k", line_width=1)
+
+def plot_line_3d(start, end, ax = None, color = 'blue', line_width = 1):
+	if ax is None:
+		ax = pv.Plotter() 
+
+	a = start
+	b = end
+
+	# Preview how this line intersects this mesh
+	line = pv.Line(a, b)
+
+	ax.add_mesh(line, color=color, line_width=line_width)
+
+if __name__ == '__main__':
+	A = np.array([[-1, 0, 0],
+				  [1, 0, 0],
+				  [0, -1, 0],
+				  [0, 1, 0],
+				  [0, 0, -1],
+				  [0, 0, 1]])
+	b = np.array([[1], [1], [1], [1], [1], [1]])
+	b2 = np.array([[-1], [2], [-1], [2], [-1], [2]])
+	ax1 = a3.Axes3D(plt.figure())
+	plot_polytope_3d(A, b, ax = ax1, color = 'red')
+	plot_polytope_3d(A, b2, ax = ax1, color = 'green')
+	plt.show()
\ No newline at end of file
-- 
GitLab