From 515a3853f1173726e8561c7bcd7c83a59230c468 Mon Sep 17 00:00:00 2001
From: Yangge Li <li213@illinois.edu>
Date: Mon, 20 Jun 2022 11:50:17 -0500
Subject: [PATCH] add function to plot point in 3d plotter

---
 demo/ball_bounces.py                 | 28 ++++++++++++++--------------
 dryvr_plus_plus/plotter/plotter3D.py | 19 ++++++++++++-------
 2 files changed, 26 insertions(+), 21 deletions(-)

diff --git a/demo/ball_bounces.py b/demo/ball_bounces.py
index 6c81aebd..d952f629 100644
--- a/demo/ball_bounces.py
+++ b/demo/ball_bounces.py
@@ -4,15 +4,15 @@ from typing import List
 # from dryvr_plus_plus.scene_verifier.map.lane import Lane
 
 class BallMode(Enum):
-    # Any model should have at least one mode
+    # TODO: Any model should have at least one mode
     Normal = auto()
-    # The one mode of this automation is called "Normal" and auto assigns it an integer value.
+    # TODO: The one mode of this automation is called "Normal" and auto assigns it an integer value.
     # Ultimately for simple models we would like to write
     # E.g., Mode = makeMode(Normal, bounce,...)
 
 class LaneMode(Enum):
     Lane0 = auto()
-    # For now this is a dummy notion of Lane
+    # TODO: For now this is a dummy notion of Lane
 
 class State:
     '''Defines the state variables of the model
@@ -29,8 +29,8 @@ class State:
 
 def controller(ego:State, others:State):
     '''Computes the possible mode transitions'''
-    output = copy.deepcopy(this)
-    '''Ego and output variable names should be flexible but 
+    output = copy.deepcopy(ego)
+    '''TODO: Ego and output variable names should be flexible but 
     currently these are somehow harcoded with the sensor'''
     # Stores the prestate first
     if ego.x<0:
@@ -40,7 +40,7 @@ def controller(ego:State, others:State):
         output.vy = -ego.vy
         output.y=0
     if ego.x>20:
-        # Q. If I change this to ego.x >= 20 then the model does not work.
+        # TODO: Q. If I change this to ego.x >= 20 then the model does not work.
         # I suspect this is because the same transition can be take many, many times.
         # We need to figure out a clean solution
         output.vx = -ego.vx
@@ -48,10 +48,10 @@ def controller(ego:State, others:State):
     if ego.y>20:
         output.vy = -ego.vy
         output.y=20
-  '''  if ego.x - others[1].x < 1 and ego.y - others[1].y < 1:
+    '''  if ego.x - others[1].x < 1 and ego.y - others[1].y < 1:
         output.vy = -ego.vy
         output.vx = -ego.vx'''
-  # We would like to be able to write something like this, but currently not allowed. 
+  # TODO: We would like to be able to write something like this, but currently not allowed. 
     return output
 
 
@@ -64,7 +64,7 @@ from dryvr_plus_plus.plotter.plotter2D import *
 import plotly.graph_objects as go
 
 if __name__ == "__main__":
-    ball_controller = '/Users/mitras/Dpp/GraphGeneration/demo/ball_bounces.py'
+    ball_controller = 'ball_bounces.py'
     bouncingBall = Scenario()
     myball1 = BallAgent('red-ball', file_name=ball_controller)
     myball2 = BallAgent('green-ball', file_name=ball_controller)
@@ -73,10 +73,10 @@ if __name__ == "__main__":
     #
     tmp_map = SimpleMap3()
     bouncingBall.set_map(tmp_map)
-    # If there is no useful map, then we should not have to write this.
+    # TODO: If there is no useful map, then we should not have to write this.
     # Default to some empty map
     bouncingBall.set_sensor(FakeSensor4())
-    # There should be a way to default to some well-defined sensor
+    # TODO: There should be a way to default to some well-defined sensor
     # for any model, without having to write an explicit sensor
     bouncingBall.set_init(
         [
@@ -88,12 +88,12 @@ if __name__ == "__main__":
             (BallMode.Normal, LaneMode.Lane0)
         ]
     )
-    # WE should be able to initialize each of the balls separately
+    # TODO: WE should be able to initialize each of the balls separately
     # this may be the cause for the VisibleDeprecationWarning
-    # Longer term: We should initialize by writing expressions like "-2 \leq myball1.x \leq 5"
+    # TODO: Longer term: We should initialize by writing expressions like "-2 \leq myball1.x \leq 5"
     # "-2 \leq myball1.x + myball2.x \leq 5"
     traces = bouncingBall.simulate(20)
-    # There should be a print({traces}) function
+    # TODO: There should be a print({traces}) function
     fig = go.Figure()
     fig = plotly_simulation_anime(traces, tmp_map, fig)
     fig.show()
diff --git a/dryvr_plus_plus/plotter/plotter3D.py b/dryvr_plus_plus/plotter/plotter3D.py
index 1ede38e0..68a3a96f 100644
--- a/dryvr_plus_plus/plotter/plotter3D.py
+++ b/dryvr_plus_plus/plotter/plotter3D.py
@@ -2,10 +2,7 @@
 # Written by: Kristina Miller
 
 import numpy as np
-import matplotlib.pyplot as plt
-import mpl_toolkits.mplot3d as a3
 
-from scipy.spatial import ConvexHull
 import polytope as pc
 import pyvista as pv
 
@@ -40,6 +37,7 @@ def plot_polytope_3d(A, b, ax = None, color = 'red', trans = 0.2, edge = True):
 	if edge:
 		edges = shell.extract_feature_edges(20)
 		ax.add_mesh(edges, color="k", line_width=1)
+	return ax
 
 def plot_line_3d(start, end, ax = None, color = 'blue', line_width = 1):
 	if ax is None:
@@ -52,6 +50,13 @@ def plot_line_3d(start, end, ax = None, color = 'blue', line_width = 1):
 	line = pv.Line(a, b)
 
 	ax.add_mesh(line, color=color, line_width=line_width)
+	return ax
+
+def plot_point_3d(points, ax=None, color='blue', point_size = 100):
+    if ax is None:
+        ax = pv.Plotter()
+    ax.add_points(points, render_points_as_spheres=True, point_size = point_size, color = color)	
+    return ax
 
 if __name__ == '__main__':
 	A = np.array([[-1, 0, 0],
@@ -62,7 +67,7 @@ if __name__ == '__main__':
 				  [0, 0, 1]])
 	b = np.array([[1], [1], [1], [1], [1], [1]])
 	b2 = np.array([[-1], [2], [-1], [2], [-1], [2]])
-	ax1 = a3.Axes3D(plt.figure())
-	plot_polytope_3d(A, b, ax = ax1, color = 'red')
-	plot_polytope_3d(A, b2, ax = ax1, color = 'green')
-	plt.show()
\ No newline at end of file
+	fig = pv.Plotter()
+	fig = plot_polytope_3d(A, b, ax = fig, color = 'red')
+	fig = plot_polytope_3d(A, b2, ax = fig, color = 'green')
+	fig.show()
\ No newline at end of file
-- 
GitLab