Skip to content
Snippets Groups Projects
Commit 515a3853 authored by li213's avatar li213
Browse files

add function to plot point in 3d plotter

parent 9041ae32
No related branches found
No related tags found
No related merge requests found
......@@ -4,15 +4,15 @@ from typing import List
# from dryvr_plus_plus.scene_verifier.map.lane import Lane
class BallMode(Enum):
# Any model should have at least one mode
# TODO: Any model should have at least one mode
Normal = auto()
# The one mode of this automation is called "Normal" and auto assigns it an integer value.
# TODO: The one mode of this automation is called "Normal" and auto assigns it an integer value.
# Ultimately for simple models we would like to write
# E.g., Mode = makeMode(Normal, bounce,...)
class LaneMode(Enum):
Lane0 = auto()
# For now this is a dummy notion of Lane
# TODO: For now this is a dummy notion of Lane
class State:
'''Defines the state variables of the model
......@@ -29,8 +29,8 @@ class State:
def controller(ego:State, others:State):
'''Computes the possible mode transitions'''
output = copy.deepcopy(this)
'''Ego and output variable names should be flexible but
output = copy.deepcopy(ego)
'''TODO: Ego and output variable names should be flexible but
currently these are somehow harcoded with the sensor'''
# Stores the prestate first
if ego.x<0:
......@@ -40,7 +40,7 @@ def controller(ego:State, others:State):
output.vy = -ego.vy
output.y=0
if ego.x>20:
# Q. If I change this to ego.x >= 20 then the model does not work.
# TODO: Q. If I change this to ego.x >= 20 then the model does not work.
# I suspect this is because the same transition can be take many, many times.
# We need to figure out a clean solution
output.vx = -ego.vx
......@@ -48,10 +48,10 @@ def controller(ego:State, others:State):
if ego.y>20:
output.vy = -ego.vy
output.y=20
''' if ego.x - others[1].x < 1 and ego.y - others[1].y < 1:
''' if ego.x - others[1].x < 1 and ego.y - others[1].y < 1:
output.vy = -ego.vy
output.vx = -ego.vx'''
# We would like to be able to write something like this, but currently not allowed.
# TODO: We would like to be able to write something like this, but currently not allowed.
return output
......@@ -64,7 +64,7 @@ from dryvr_plus_plus.plotter.plotter2D import *
import plotly.graph_objects as go
if __name__ == "__main__":
ball_controller = '/Users/mitras/Dpp/GraphGeneration/demo/ball_bounces.py'
ball_controller = 'ball_bounces.py'
bouncingBall = Scenario()
myball1 = BallAgent('red-ball', file_name=ball_controller)
myball2 = BallAgent('green-ball', file_name=ball_controller)
......@@ -73,10 +73,10 @@ if __name__ == "__main__":
#
tmp_map = SimpleMap3()
bouncingBall.set_map(tmp_map)
# If there is no useful map, then we should not have to write this.
# TODO: If there is no useful map, then we should not have to write this.
# Default to some empty map
bouncingBall.set_sensor(FakeSensor4())
# There should be a way to default to some well-defined sensor
# TODO: There should be a way to default to some well-defined sensor
# for any model, without having to write an explicit sensor
bouncingBall.set_init(
[
......@@ -88,12 +88,12 @@ if __name__ == "__main__":
(BallMode.Normal, LaneMode.Lane0)
]
)
# WE should be able to initialize each of the balls separately
# TODO: WE should be able to initialize each of the balls separately
# this may be the cause for the VisibleDeprecationWarning
# Longer term: We should initialize by writing expressions like "-2 \leq myball1.x \leq 5"
# TODO: Longer term: We should initialize by writing expressions like "-2 \leq myball1.x \leq 5"
# "-2 \leq myball1.x + myball2.x \leq 5"
traces = bouncingBall.simulate(20)
# There should be a print({traces}) function
# TODO: There should be a print({traces}) function
fig = go.Figure()
fig = plotly_simulation_anime(traces, tmp_map, fig)
fig.show()
......
......@@ -2,10 +2,7 @@
# Written by: Kristina Miller
import numpy as np
import matplotlib.pyplot as plt
import mpl_toolkits.mplot3d as a3
from scipy.spatial import ConvexHull
import polytope as pc
import pyvista as pv
......@@ -40,6 +37,7 @@ def plot_polytope_3d(A, b, ax = None, color = 'red', trans = 0.2, edge = True):
if edge:
edges = shell.extract_feature_edges(20)
ax.add_mesh(edges, color="k", line_width=1)
return ax
def plot_line_3d(start, end, ax = None, color = 'blue', line_width = 1):
if ax is None:
......@@ -52,6 +50,13 @@ def plot_line_3d(start, end, ax = None, color = 'blue', line_width = 1):
line = pv.Line(a, b)
ax.add_mesh(line, color=color, line_width=line_width)
return ax
def plot_point_3d(points, ax=None, color='blue', point_size = 100):
if ax is None:
ax = pv.Plotter()
ax.add_points(points, render_points_as_spheres=True, point_size = point_size, color = color)
return ax
if __name__ == '__main__':
A = np.array([[-1, 0, 0],
......@@ -62,7 +67,7 @@ if __name__ == '__main__':
[0, 0, 1]])
b = np.array([[1], [1], [1], [1], [1], [1]])
b2 = np.array([[-1], [2], [-1], [2], [-1], [2]])
ax1 = a3.Axes3D(plt.figure())
plot_polytope_3d(A, b, ax = ax1, color = 'red')
plot_polytope_3d(A, b2, ax = ax1, color = 'green')
plt.show()
\ No newline at end of file
fig = pv.Plotter()
fig = plot_polytope_3d(A, b, ax = fig, color = 'red')
fig = plot_polytope_3d(A, b2, ax = fig, color = 'green')
fig.show()
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment