diff --git a/.gitignore b/.gitignore index 03a577b0b713ed8adadda7f6f583e58774f99186..682e59a8b61966549568fdaca3cccc0f809f50fe 100644 --- a/.gitignore +++ b/.gitignore @@ -8,4 +8,5 @@ venv/ .VSCodeCounter/ dist/dryvr_plus_plus-0.1-py3.8.egg tmp/ -dist/ \ No newline at end of file +dist/ +demo/quadrotor/output.json \ No newline at end of file diff --git a/README.md b/README.md index 34bed013c4507fc2ca8b9d303e03fa765e2b9fcf..19f31305b170fb9e8ec903a93d306178925a031a 100644 --- a/README.md +++ b/README.md @@ -16,19 +16,14 @@ pip install -r requirements.txt ``` ## Examples -The package comes with two controller examples inside folder ```demo/``` -- The first example consists a scenario with two vehicles. The second vehicle will brake and stop when it detect the first vehicle in front. The first example can be run by using command +The package comes with several examples in the ```demo/``` folder +- Run examples as: ``` python3 demo1.py ``` -- The second example consists a scenario with two vehicles, which can perform lane switch based on their relative position. -The second example can be run using command - -``` -python3 demo2.py -``` +Read the comments in ```ball_bounces.py``` to learn how to create new agents and scenarios. More detailed tutorials will be provided later. ## Package Structure diff --git a/demo/F16/F16_waypoint_scene.py b/demo/F16/F16_waypoint_scene.py new file mode 100644 index 0000000000000000000000000000000000000000..8f557248fd28fe795b5038e35b9c42f4e4684ee1 --- /dev/null +++ b/demo/F16/F16_waypoint_scene.py @@ -0,0 +1,154 @@ +''' +F16 scenario Sayan Mitra +Derived from Stanley Bak's version +''' + + +from typing import Tuple, List + +import numpy as np +from scipy.integrate import ode +from enum import Enum, auto + +from demo.F16.aerobench.run_f16_sim import F16Agent +from dryvr_plus_plus import LaneMap +from dryvr_plus_plus import Scenario + + +import copy +import math +from numpy import deg2rad +import matplotlib.pyplot as plt +from aerobench.run_f16_sim import run_f16_sim +from aerobench.visualize import plot +from waypoint_autopilot import WaypointAutopilot + + +class F16Mode(Enum): + '''Defines the discrete modes of a single agent''' + NORMAL = auto() + # TODO: The one mode of this automation is called "NORMAL" and auto assigns it an integer value. + + +class State: + '''Defines the state variables of the model + Both discrete and continuous variables. + Initial values defined here do not matter. + ''' + mode: F16Mode + ### Continuous variable initial conditions ### + power = 9 # engine power level (0-10) + + # Default alpha & beta + alpha = deg2rad(2.1215) # Trim Angle of Attack (rad) + beta = 0 # Side slip angle (rad) + + # Initial Attitude + alt = 3800 # altitude (ft) + vt = 540 # initial velocity (ft/sec) + phi = 0 # Roll angle from wings level (rad) + theta = 0 # Pitch angle from nose level (rad) + psi = math.pi / 8 # Yaw angle from North (rad) + + p = 0 + q = 0 + r = 0 + pn = 0 + pe = 0 + + # Build Initial Condition Vectors + # state = [vt, alpha, beta, phi, theta, psi, P, Q, R, pn, pe, h, pow] + init = [vt, alpha, beta, phi, theta, psi, 0, 0, 0, 0, 0, alt, power] + def __init__(self, vt, alpha, beta, phi, theta, psi, p, q, r, pn, pe, alt, power, mode:F16Mode): + pass + +def controller(ego:State, others:State): + '''Computes the possible mode transitions + For now this is an empty controller function. + Coming soon. Waypoint transitions. Platooning.''' + output = copy.deepcopy(ego) + return output + +def main(): + ''' The main function defines and simulates a scene. + Defining and using a scenario involves the following 5 easy steps: + 1. creating a basic scenario object with Scenario() + 2. defining the agents that will populate the object. Here a single F16 agent. + 3. adding the agents to the scenario using .add_agent() + 4. initializing the agents for this scenario. + Note that agents are only initialized *in* a scenario, not individually outside a scenario + 5. genetating the simulation traces or computing the reachable states + ''' + F16waypointScene = Scenario() + F16Controller = './F16_waypoint_scene.py' + + # Resume here. This next line is the problem + # Fighter1 = F16Agent('Fighter1', file_name=F16Controller) + + ### Initial Conditions ### + power = 9 # engine power level (0-10) + + # Default alpha & beta + alpha = deg2rad(2.1215) # Trim Angle of Attack (rad) + beta = 0 # Side slip angle (rad) + + # Initial Attitude + alt = 3800 # altitude (ft) + vt = 540 # initial velocity (ft/sec) + phi = 0 # Roll angle from wings level (rad) + theta = 0 # Pitch angle from nose level (rad) + psi = math.pi/8 # Yaw angle from North (rad) + + # Build Initial Condition Vectors + # state = [vt, alpha, beta, phi, theta, psi, p, q, r, pn, pe, h, pow] + init = [vt, alpha, beta, phi, theta, psi, 0, 0, 0, 0, 0, alt, power] + tmax = 70 # simulation time + + # make waypoint list + e_pt = 1000 + n_pt = 3000 + h_pt = 4000 + + waypoints = [[e_pt, n_pt, h_pt], + [e_pt + 2000, n_pt + 5000, h_pt - 100], + [e_pt - 2000, n_pt + 15000, h_pt - 250], + [e_pt - 500, n_pt + 25000, h_pt]] + + ap = WaypointAutopilot(waypoints, stdout=True) + + step = 1/30 + extended_states = True + '''Main call to simulation''' + res = run_f16_sim(init, tmax, ap, step=step, extended_states=extended_states, integrator_str='rk45') + + print(f"Simulation Completed in {round(res['runtime'], 2)} seconds (extended_states={extended_states})") + + plot.plot_single(res, 'alt', title='Altitude (ft)') + filename = 'alt.png' + plt.savefig(filename) + print(f"Made {filename}") + + plot.plot_overhead(res, waypoints=waypoints) + filename = 'overhead.png' + plt.savefig(filename) + print(f"Made {filename}") + + plot.plot_attitude(res) + filename = 'attitude.png' + plt.savefig(filename) + print(f"Made {filename}") + + # plot inner loop controls + references + plot.plot_inner_loop(res) + filename = 'inner_loop.png' + plt.savefig(filename) + print(f"Made {filename}") + + # plot outer loop controls + references + plot.plot_outer_loop(res) + filename = 'outer_loop.png' + plt.savefig(filename) + print(f"Made {filename}") + +if __name__ == '__main__': + main() diff --git a/demo/F16/README.md b/demo/F16/README.md new file mode 100644 index 0000000000000000000000000000000000000000..b668344ea8baee8c6a13ff94b6d1f3dbb9c7f6bf --- /dev/null +++ b/demo/F16/README.md @@ -0,0 +1,6 @@ +# Adaptation of F16 model + +This is an adaptation of the Python F16 model from [Stanley Bak](https://github.com/stanleybak/AeroBenchVVPython). +All the files are copied to make this example a standalone executable. Requires you to add ```/Users/mitras/Dpp/GraphGeneration/demo/F16``` to PYTHONPATH. + +Test by ```python3 F16_waypoint_scene.py``` diff --git a/demo/F16/aerobench/examples/waypoint/F16_waypoint_scene.py b/demo/F16/aerobench/examples/waypoint/F16_waypoint_scene.py new file mode 100644 index 0000000000000000000000000000000000000000..8984fc182bcfbf622e6bd3a7f989986003fd161a --- /dev/null +++ b/demo/F16/aerobench/examples/waypoint/F16_waypoint_scene.py @@ -0,0 +1,96 @@ +''' +Sayan Mitra +Derived from Stanley Bak's version +''' + + +from typing import Tuple, List + +import numpy as np +from scipy.integrate import ode + +from dryvr_plus_plus.scene_verifier.agents.base_agent import BaseAgent +from dryvr_plus_plus.scene_verifier.map.lane_map import LaneMap +from dryvr_plus_plus.scene_verifier.code_parser.pythonparser import EmptyAst + + +import math + +from numpy import deg2rad +import matplotlib.pyplot as plt + +from aerobench.run_f16_sim import run_f16_sim + +from aerobench.visualize import plot + +from waypoint_autopilot import WaypointAutopilot + +def main(): + 'main function' + + ### Initial Conditions ### + power = 9 # engine power level (0-10) + + # Default alpha & beta + alpha = deg2rad(2.1215) # Trim Angle of Attack (rad) + beta = 0 # Side slip angle (rad) + + # Initial Attitude + alt = 3800 # altitude (ft) + vt = 540 # initial velocity (ft/sec) + phi = 0 # Roll angle from wings level (rad) + theta = 0 # Pitch angle from nose level (rad) + psi = math.pi/8 # Yaw angle from North (rad) + + # Build Initial Condition Vectors + # state = [vt, alpha, beta, phi, theta, psi, P, Q, R, pn, pe, h, pow] + init = [vt, alpha, beta, phi, theta, psi, 0, 0, 0, 0, 0, alt, power] + tmax = 70 # simulation time + + # make waypoint list + e_pt = 1000 + n_pt = 3000 + h_pt = 4000 + + waypoints = [[e_pt, n_pt, h_pt], + [e_pt + 2000, n_pt + 5000, h_pt - 100], + [e_pt - 2000, n_pt + 15000, h_pt - 250], + [e_pt - 500, n_pt + 25000, h_pt]] + + ap = WaypointAutopilot(waypoints, stdout=True) + + step = 1/30 + extended_states = True + res = run_f16_sim(init, tmax, ap, step=step, extended_states=extended_states, integrator_str='rk45') + + print(f"Simulation Completed in {round(res['runtime'], 2)} seconds (extended_states={extended_states})") + + plot.plot_single(res, 'alt', title='Altitude (ft)') + filename = 'alt.png' + plt.savefig(filename) + print(f"Made {filename}") + + plot.plot_overhead(res, waypoints=waypoints) + filename = 'overhead.png' + plt.savefig(filename) + print(f"Made {filename}") + + plot.plot_attitude(res) + filename = 'attitude.png' + plt.savefig(filename) + print(f"Made {filename}") + + # plot inner loop controls + references + plot.plot_inner_loop(res) + filename = 'inner_loop.png' + plt.savefig(filename) + print(f"Made {filename}") + + # plot outer loop controls + references + plot.plot_outer_loop(res) + filename = 'outer_loop.png' + plt.savefig(filename) + print(f"Made {filename}") + +if __name__ == '__main__': + main() diff --git a/demo/F16/aerobench/examples/waypoint/run_u_turn.py b/demo/F16/aerobench/examples/waypoint/run_u_turn.py new file mode 100644 index 0000000000000000000000000000000000000000..f7f77701f1497be07008a68aeb41bc34d3286825 --- /dev/null +++ b/demo/F16/aerobench/examples/waypoint/run_u_turn.py @@ -0,0 +1,81 @@ +''' +Stanley Bak + +should match 'u_turn' scenario from matlab version +''' + +import math + +from numpy import deg2rad +import matplotlib.pyplot as plt + +from aerobench.run_f16_sim import run_f16_sim + +from aerobench.visualize import plot + +from waypoint_autopilot import WaypointAutopilot + +def main(): + 'main function' + + ### Initial Conditions ### + power = 9 # engine power level (0-10) + + # Default alpha & beta + alpha = deg2rad(2.1215) # Trim Angle of Attack (rad) + beta = 0 # Side slip angle (rad) + + # Initial Attitude + alt = 1500 # altitude (ft) + vt = 540 # initial velocity (ft/sec) + phi = 0 # Roll angle from wings level (rad) + theta = 0 # Pitch angle from nose level (rad) + psi = 0 # Yaw angle from North (rad) + + # Build Initial Condition Vectors + # state = [vt, alpha, beta, phi, theta, psi, P, Q, R, pn, pe, h, pow] + init = [vt, alpha, beta, phi, theta, psi, 0, 0, 0, 0, 0, alt, power] + tmax = 150 # simulation time + + # make waypoint list + waypoints = [[-5000, -7500, alt], + [-15000, -7500, alt], + [-20000, 0, alt+500]] + + ap = WaypointAutopilot(waypoints, stdout=True) + + step = 1/30 + extended_states = True + res = run_f16_sim(init, tmax, ap, step=step, extended_states=extended_states, integrator_str='rk45') + + print(f"Simulation Completed in {round(res['runtime'], 2)} seconds (extended_states={extended_states})") + + plot.plot_single(res, 'alt', title='Altitude (ft)') + filename = 'alt.png' + plt.savefig(filename) + print(f"Made {filename}") + + plot.plot_overhead(res, waypoints=waypoints) + filename = 'overhead.png' + plt.savefig(filename) + print(f"Made {filename}") + + plot.plot_attitude(res) + filename = 'attitude.png' + plt.savefig(filename) + print(f"Made {filename}") + + # plot inner loop controls + references + plot.plot_inner_loop(res) + filename = 'inner_loop.png' + plt.savefig(filename) + print(f"Made {filename}") + + # plot outer loop controls + references + plot.plot_outer_loop(res) + filename = 'outer_loop.png' + plt.savefig(filename) + print(f"Made {filename}") + +if __name__ == '__main__': + main() diff --git a/demo/F16/aerobench/examples/waypoint/run_waypoint.py b/demo/F16/aerobench/examples/waypoint/run_waypoint.py new file mode 100644 index 0000000000000000000000000000000000000000..e2952a6b5034a24e045e8ffba2ade2db21a6f38b --- /dev/null +++ b/demo/F16/aerobench/examples/waypoint/run_waypoint.py @@ -0,0 +1,86 @@ +''' +Stanley Bak + +should match 'waypoint' scenario from matlab version +''' + +import math + +from numpy import deg2rad +import matplotlib.pyplot as plt + +from aerobench.run_f16_sim import run_f16_sim + +from aerobench.visualize import plot + +from waypoint_autopilot import WaypointAutopilot + +def main(): + 'main function' + + ### Initial Conditions ### + power = 9 # engine power level (0-10) + + # Default alpha & beta + alpha = deg2rad(2.1215) # Trim Angle of Attack (rad) + beta = 0 # Side slip angle (rad) + + # Initial Attitude + alt = 3800 # altitude (ft) + vt = 540 # initial velocity (ft/sec) + phi = 0 # Roll angle from wings level (rad) + theta = 0 # Pitch angle from nose level (rad) + psi = math.pi/8 # Yaw angle from North (rad) + + # Build Initial Condition Vectors + # state = [vt, alpha, beta, phi, theta, psi, P, Q, R, pn, pe, h, pow] + init = [vt, alpha, beta, phi, theta, psi, 0, 0, 0, 0, 0, alt, power] + tmax = 70 # simulation time + + # make waypoint list + e_pt = 1000 + n_pt = 3000 + h_pt = 4000 + + waypoints = [[e_pt, n_pt, h_pt], + [e_pt + 2000, n_pt + 5000, h_pt - 100], + [e_pt - 2000, n_pt + 15000, h_pt - 250], + [e_pt - 500, n_pt + 25000, h_pt]] + + ap = WaypointAutopilot(waypoints, stdout=True) + + step = 1/30 + extended_states = True + res = run_f16_sim(init, tmax, ap, step=step, extended_states=extended_states, integrator_str='rk45') + + print(f"Simulation Completed in {round(res['runtime'], 2)} seconds (extended_states={extended_states})") + + plot.plot_single(res, 'alt', title='Altitude (ft)') + filename = 'alt.png' + plt.savefig(filename) + print(f"Made {filename}") + + plot.plot_overhead(res, waypoints=waypoints) + filename = 'overhead.png' + plt.savefig(filename) + print(f"Made {filename}") + + plot.plot_attitude(res) + filename = 'attitude.png' + plt.savefig(filename) + print(f"Made {filename}") + + # plot inner loop controls + references + plot.plot_inner_loop(res) + filename = 'inner_loop.png' + plt.savefig(filename) + print(f"Made {filename}") + + # plot outer loop controls + references + plot.plot_outer_loop(res) + filename = 'outer_loop.png' + plt.savefig(filename) + print(f"Made {filename}") + +if __name__ == '__main__': + main() diff --git a/demo/F16/aerobench/examples/waypoint/waypoint_autopilot.py b/demo/F16/aerobench/examples/waypoint/waypoint_autopilot.py new file mode 100644 index 0000000000000000000000000000000000000000..0c9575ff51826eb2452f965af05e6cf155fa2c11 --- /dev/null +++ b/demo/F16/aerobench/examples/waypoint/waypoint_autopilot.py @@ -0,0 +1,307 @@ +'''waypoint autopilot + +ported from matlab v2 +''' + +from math import pi, atan2, sqrt, sin, cos, asin + +import numpy as np + +from aerobench.highlevel.autopilot import Autopilot +from aerobench.util import StateIndex +from aerobench.lowlevel.low_level_controller import LowLevelController + +class WaypointAutopilot(Autopilot): + '''waypoint follower autopilot''' + + def __init__(self, waypoints, gain_str='old', stdout=False): + 'waypoints is a list of 3-tuples' + + self.stdout = stdout + self.waypoints = waypoints + self.waypoint_index = 0 + + # waypoint config + self.cfg_slant_range_threshold = 250 + + # default control when not waypoint tracking + self.cfg_u_ol_default = (0, 0, 0, 0.3) + + # control config + # Gains for speed control + self.cfg_k_vt = 0.25 + self.cfg_airspeed = 550 + + # Gains for altitude tracking + self.cfg_k_alt = 0.005 + self.cfg_k_h_dot = 0.02 + + # Gains for heading tracking + self.cfg_k_prop_psi = 5 + self.cfg_k_der_psi = 0.5 + + # Gains for roll tracking + self.cfg_k_prop_phi = 0.75 + self.cfg_k_der_phi = 0.5 + self.cfg_max_bank_deg = 65 # maximum bank angle setpoint + # v2 was 0.5, 0.9 + + # Ranges for Nz + self.cfg_max_nz_cmd = 4 + self.cfg_min_nz_cmd = -1 + + self.done_time = 0.0 + + llc = LowLevelController(gain_str=gain_str) + + Autopilot.__init__(self, 'Waypoint 1', llc=llc) + + def log(self, s): + 'print to terminal if stdout is true' + + if self.stdout: + print(s) + + def get_u_ref(self, _t, x_f16): + '''get the reference input signals''' + + if self.mode != "Done": + psi_cmd = self.get_waypoint_data(x_f16)[0] + + # Get desired roll angle given desired heading + phi_cmd = self.get_phi_to_track_heading(x_f16, psi_cmd) + ps_cmd = self.track_roll_angle(x_f16, phi_cmd) + + nz_cmd = self.track_altitude(x_f16) + throttle = self.track_airspeed(x_f16) + else: + # Waypoint Following complete: fly level. + throttle = self.track_airspeed(x_f16) + ps_cmd = self.track_roll_angle(x_f16, 0) + nz_cmd = self.track_altitude_wings_level(x_f16) + + # trim to limits + nz_cmd = max(self.cfg_min_nz_cmd, min(self.cfg_max_nz_cmd, nz_cmd)) + throttle = max(min(throttle, 1), 0) + + # Create reference vector + rv = [nz_cmd, ps_cmd, 0, throttle] + + return rv + + def track_altitude(self, x_f16): + 'get nz to track altitude, taking turning into account' + + h_cmd = self.waypoints[self.waypoint_index][2] + + h = x_f16[StateIndex.ALT] + phi = x_f16[StateIndex.PHI] + + # Calculate altitude error (positive => below target alt) + h_error = h_cmd - h + nz_alt = self.track_altitude_wings_level(x_f16) + nz_roll = get_nz_for_level_turn_ol(x_f16) + + if h_error > 0: + # Ascend wings level or banked + nz = nz_alt + nz_roll + elif abs(phi) < np.deg2rad(15): + # Descend wings (close enough to) level + nz = nz_alt + nz_roll + else: + # Descend in bank (no negative Gs) + nz = max(0, nz_alt + nz_roll) + + return nz + + def get_phi_to_track_heading(self, x_f16, psi_cmd): + 'get phi from psi_cmd' + + # PD Control on heading angle using phi_cmd as control + + # Pull out important variables for ease of use + psi = wrap_to_pi(x_f16[StateIndex.PSI]) + r = x_f16[StateIndex.R] + + # Calculate PD control + psi_err = wrap_to_pi(psi_cmd - psi) + + phi_cmd = psi_err * self.cfg_k_prop_psi - r * self.cfg_k_der_psi + + # Bound to acceptable bank angles: + max_bank_rad = np.deg2rad(self.cfg_max_bank_deg) + + phi_cmd = min(max(phi_cmd, -max_bank_rad), max_bank_rad) + + return phi_cmd + + def track_roll_angle(self, x_f16, phi_cmd): + 'get roll angle command (ps_cmd)' + + # PD control on roll angle using stability roll rate + + # Pull out important variables for ease of use + phi = x_f16[StateIndex.PHI] + p = x_f16[StateIndex.P] + + # Calculate PD control + ps = (phi_cmd-phi) * self.cfg_k_prop_phi - p * self.cfg_k_der_phi + + return ps + + def track_airspeed(self, x_f16): + 'get throttle command' + + vt_cmd = self.cfg_airspeed + + # Proportional control on airspeed using throttle + throttle = self.cfg_k_vt * (vt_cmd - x_f16[StateIndex.VT]) + + return throttle + + def track_altitude_wings_level(self, x_f16): + 'get nz to track altitude' + + i = self.waypoint_index if self.waypoint_index < len(self.waypoints) else -1 + + h_cmd = self.waypoints[i][2] + + vt = x_f16[StateIndex.VT] + h = x_f16[StateIndex.ALT] + + # Proportional-Derivative Control + h_error = h_cmd - h + gamma = get_path_angle(x_f16) + h_dot = vt * sin(gamma) # Calculated, not differentiated + + # Calculate Nz command + nz = self.cfg_k_alt*h_error - self.cfg_k_h_dot*h_dot + + return nz + + def is_finished(self, t, x_f16): + 'is the maneuver done?' + + rv = self.waypoint_index >= len(self.waypoints) and self.done_time + 5.0 < t + + return rv + + def advance_discrete_mode(self, t, x_f16): + ''' + advance the discrete state based on the current aircraft state. Returns True iff the discrete state + has changed. + ''' + + if self.waypoint_index < len(self.waypoints): + slant_range = self.get_waypoint_data(x_f16)[-1] + + if slant_range < self.cfg_slant_range_threshold: + self.waypoint_index += 1 + + if self.waypoint_index >= len(self.waypoints): + self.done_time = t + + premode = self.mode + + if self.waypoint_index >= len(self.waypoints): + self.mode = 'Done' + else: + self.mode = f'Waypoint {self.waypoint_index + 1}' + + rv = premode != self.mode + + if rv: + self.log(f"Waypoint transition {premode} -> {self.mode} at time {t}") + + return rv + + def get_waypoint_data(self, x_f16): + '''returns current waypoint data tuple based on the current waypoint: + + (heading, inclination, horiz_range, vert_range, slant_range) + + heading = heading to tgt, equivalent to psi (rad) + inclination = polar angle to tgt, equivalent to theta (rad) + horiz_range = horizontal range to tgt (ft) + vert_range = vertical range to tgt (ft) + slant_range = total range to tgt (ft) + ''' + + waypoint = self.waypoints[self.waypoint_index] + + e_pos = x_f16[StateIndex.POSE] + n_pos = x_f16[StateIndex.POSN] + alt = x_f16[StateIndex.ALT] + + delta = [waypoint[i] - [e_pos, n_pos, alt][i] for i in range(3)] + + _, inclination, slant_range = cart2sph(delta) + + heading = wrap_to_pi(pi/2 - atan2(delta[1], delta[0])) + + horiz_range = np.linalg.norm(delta[0:2]) + vert_range = np.linalg.norm(delta[2]) + + return heading, inclination, horiz_range, vert_range, slant_range + +def get_nz_for_level_turn_ol(x_f16): + 'get nz to do a level turn' + + # Pull g's to maintain altitude during bank based on trig + + # Calculate theta + phi = x_f16[StateIndex.PHI] + + if abs(phi): # if cos(phi) ~= 0, basically + nz = 1 / cos(phi) - 1 # Keeps plane at altitude + else: + nz = 0 + + return nz + +def get_path_angle(x_f16): + 'get the path angle gamma' + + alpha = x_f16[StateIndex.ALPHA] # AoA (rad) + beta = x_f16[StateIndex.BETA] # Sideslip (rad) + phi = x_f16[StateIndex.PHI] # Roll anle (rad) + theta = x_f16[StateIndex.THETA] # Pitch angle (rad) + + gamma = asin((cos(alpha)*sin(theta)- \ + sin(alpha)*cos(theta)*cos(phi))*cos(beta) - \ + (cos(theta)*sin(phi))*sin(beta)) + + return gamma + +def wrap_to_pi(psi_rad): + '''handle angle wrapping + + returns equivelent angle in range [-pi, pi] + ''' + + rv = psi_rad % (2 * pi) + + if rv > pi: + rv -= 2 * pi + + return rv + +def cart2sph(pt3d): + ''' + Cartesian to spherical coordinates + + returns az, elev, r + ''' + + x, y, z = pt3d + + h = sqrt(x*x + y*y) + r = sqrt(h*h + z*z) + + elev = atan2(z, h) + az = atan2(y, x) + + return az, elev, r + +if __name__ == '__main__': + print("Autopulot script not meant to be run directly.") diff --git a/demo/F16/aerobench/highlevel/autopilot.py b/demo/F16/aerobench/highlevel/autopilot.py new file mode 100644 index 0000000000000000000000000000000000000000..d369cdc82cbd69c46ffcb9fc6769b407ca78570e --- /dev/null +++ b/demo/F16/aerobench/highlevel/autopilot.py @@ -0,0 +1,97 @@ +''' +Stanley Bak +Autopilot State-Machine Logic + +There is a high-level advance_discrete_state() function, which checks if we should change the current discrete state, +and a get_u_ref(f16_state) function, which gets the reference inputs at the current discrete state. +''' + +import abc +from math import pi + +import numpy as np +from numpy import deg2rad + +from aerobench.lowlevel.low_level_controller import LowLevelController +from aerobench.util import Freezable + +class Autopilot(Freezable): + '''A container object for the hybrid automaton logic for a particular autopilot instance''' + + def __init__(self, init_mode, llc=None): + + assert isinstance(init_mode, str), 'init_mode should be a string' + + if llc is None: + # use default + llc = LowLevelController() + + self.llc = llc + self.xequil = llc.xequil + self.uequil = llc.uequil + + self.mode = init_mode # discrete state, this should be overwritten by subclasses + + self.freeze_attrs() + + def advance_discrete_mode(self, t, x_f16): + ''' + advance the discrete mode based on the current aircraft state. Returns True iff the discrete mode + has changed. It's also suggested to update self.mode to the current mode name. + ''' + + return False + + def is_finished(self, t, x_f16): + ''' + returns True if the simulation should stop (for example, after maneuver completes) + + this is called after advance_discrete_state + ''' + + return False + + @abc.abstractmethod + def get_u_ref(self, t, x_f16): + ''' + for the current discrete state, get the reference inputs signals. Override this one + in subclasses. + + returns four values per aircraft: Nz, ps, Ny_r, throttle + ''' + + return + + def get_checked_u_ref(self, t, x_f16): + ''' + for the current discrete state, get the reference inputs signals and check them against ctrl limits + ''' + + rv = np.array(self.get_u_ref(t, x_f16), dtype=float) + + assert rv.size % 4 == 0, "get_u_ref should return Nz, ps, Ny_r, throttle for each aircraft" + + for i in range(rv.size //4): + Nz, _ps, _Ny_r, _throttle = rv[4*i:4*(i+1)] + + l, u = self.llc.ctrlLimits.NzMin, self.llc.ctrlLimits.NzMax + assert l <= Nz <= u, f"autopilot commanded invalid Nz ({Nz}). Not in range [{l}, {u}]" + + return rv + +class FixedSpeedAutopilot(Autopilot): + '''Simple Autopilot that gives a fixed speed command using proportional control''' + + def __init__(self, setpoint, p_gain): + self.setpoint = setpoint + self.p_gain = p_gain + + init_mode = 'tracking speed' + Autopilot.__init__(self, init_mode) + + def get_u_ref(self, t, x_f16): + '''for the current discrete state, get the reference inputs signals''' + + x_dif = self.setpoint - x_f16[0] + + return 0, 0, 0, self.p_gain * x_dif diff --git a/demo/F16/aerobench/highlevel/controlled_f16.py b/demo/F16/aerobench/highlevel/controlled_f16.py new file mode 100644 index 0000000000000000000000000000000000000000..c7ee115954ed94673ddbf8de28c9dd0eb7083937 --- /dev/null +++ b/demo/F16/aerobench/highlevel/controlled_f16.py @@ -0,0 +1,60 @@ +''' +Stanley Bak +Python Version of F-16 GCAS +ODE derivative code (controlled F16) +''' + +from math import sin, cos + +import numpy as np +from numpy import deg2rad + +from aerobench.lowlevel.subf16_model import subf16_model +from aerobench.lowlevel.low_level_controller import LowLevelController + +def controlled_f16(t, x_f16, u_ref, llc, f16_model='morelli', v2_integrators=False): + 'returns the LQR-controlled F-16 state derivatives and more' + + assert isinstance(x_f16, np.ndarray) + assert isinstance(llc, LowLevelController) + assert u_ref.size == 4 + + assert f16_model in ['stevens', 'morelli'], 'Unknown F16_model: {}'.format(f16_model) + + x_ctrl, u_deg = llc.get_u_deg(u_ref, x_f16) + + # Note: Control vector (u) for subF16 is in units of degrees + xd_model, Nz, Ny, _, _ = subf16_model(x_f16[0:13], u_deg, f16_model) + + if v2_integrators: + # integrators from matlab v2 model + ps = xd_model[6] * cos(xd_model[1]) + xd_model[8] * sin(xd_model[1]) + + Ny_r = Ny + xd_model[8] + else: + # Nonlinear (Actual): ps = p * cos(alpha) + r * sin(alpha) + ps = x_ctrl[4] * cos(x_ctrl[0]) + x_ctrl[5] * sin(x_ctrl[0]) + + # Calculate (side force + yaw rate) term + Ny_r = Ny + x_ctrl[5] + + xd = np.zeros((x_f16.shape[0],)) + xd[:len(xd_model)] = xd_model + + # integrators from low-level controller + start = len(xd_model) + end = start + llc.get_num_integrators() + int_der = llc.get_integrator_derivatives(t, x_f16, u_ref, Nz, ps, Ny_r) + xd[start:end] = int_der + + # Convert all degree values to radians for output + u_rad = np.zeros((7,)) # throt, ele, ail, rud, Nz_ref, ps_ref, Ny_r_ref + + u_rad[0] = u_deg[0] # throttle + + for i in range(1, 4): + u_rad[i] = deg2rad(u_deg[i]) + + u_rad[4:7] = u_ref[0:3] # inner-loop commands are 4-7 + + return xd, u_rad, Nz, ps, Ny_r diff --git a/demo/F16/aerobench/lowlevel/adc.py b/demo/F16/aerobench/lowlevel/adc.py new file mode 100644 index 0000000000000000000000000000000000000000..a70390eb25f9e91855152adf31799e17297e34e3 --- /dev/null +++ b/demo/F16/aerobench/lowlevel/adc.py @@ -0,0 +1,37 @@ +''' +Stanley Bak +adc.py for F-16 model +''' + +from math import sqrt + +def adc(vt, alt): + '''converts velocity (vt) and altitude (alt) to mach number (amach) and dynamic pressure (qbar) + + See pages 63-65 of Stevens & Lewis, "Aircraft Control and Simulation", 2nd edition + ''' + + # vt = freestream air speed + + ro = 2.377e-3 + tfac = 1 - .703e-5 * alt + + if alt >= 35000: # in stratosphere + t = 390 + else: + t = 519 * tfac # 3 rankine per atmosphere (3 rankine per 1000 ft) + + # rho = freestream mass density + rho = ro * tfac**4.14 + + # a = speed of sound at the ambient conditions + # speed of sound in a fluid is the sqrt of the quotient of the modulus of elasticity over the mass density + a = sqrt(1.4 * 1716.3 * t) + + # amach = mach number + amach = vt / a + + # qbar = dynamic pressure + qbar = .5 * rho * vt * vt + + return amach, qbar diff --git a/demo/F16/aerobench/lowlevel/cl.py b/demo/F16/aerobench/lowlevel/cl.py new file mode 100644 index 0000000000000000000000000000000000000000..6bff35919a610317771baa8bc22ffcfacfe48973 --- /dev/null +++ b/demo/F16/aerobench/lowlevel/cl.py @@ -0,0 +1,53 @@ +''' +Stanley Bak +F-16 GCAS in Python +cl function +''' + +import numpy as np + +from aerobench.util import fix, sign + +def cl(alpha, beta): + 'cl function' + + a = np.array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], \ + [-.001, -.004, -.008, -.012, -.016, -.022, -.022, -.021, -.015, -.008, -.013, -.015], \ + [-.003, -.009, -.017, -.024, -.030, -.041, -.045, -.040, -.016, -.002, -.010, -.019], \ + [-.001, -.010, -.020, -.030, -.039, -.054, -.057, -.054, -.023, -.006, -.014, -.027], \ + [.000, -.010, -.022, -.034, -.047, -.060, -.069, -.067, -.033, -.036, -.035, -.035], \ + [.007, -.010, -.023, -.034, -.049, -.063, -.081, -.079, -.060, -.058, -.062, -.059], \ + [.009, -.011, -.023, -.037, -.050, -.068, -.089, -.088, -.091, -.076, -.077, -.076]], dtype=float).T + + s = .2 * alpha + k = fix(s) + + if k <= -2: + k = -1 + + if k >= 9: + k = 8 + + da = s - k + l = k + fix(1.1 * sign(da)) + s = .2 * abs(beta) + m = fix(s) + if m == 0: + m = 1 + + if m >= 6: + m = 5 + + db = s - m + n = m + fix(1.1 * sign(db)) + l = l + 3 + k = k + 3 + m = m + 1 + n = n + 1 + t = a[k-1, m-1] + u = a[k-1, n-1] + v = t + abs(da) * (a[l-1, m-1] - t) + w = u + abs(da) * (a[l-1, n-1] - u) + dum = v + (w - v) * abs(db) + + return dum * sign(beta) diff --git a/demo/F16/aerobench/lowlevel/clf16.py b/demo/F16/aerobench/lowlevel/clf16.py new file mode 100644 index 0000000000000000000000000000000000000000..a09addfe2445456f98efb89c6fbcee662be911e7 --- /dev/null +++ b/demo/F16/aerobench/lowlevel/clf16.py @@ -0,0 +1,85 @@ +''' +Stanley Bak +clf16.py for F-16 model + +This is the objective function for finding the trim condition of the initial states +''' + +from math import asin, sin + +from tgear import tgear +from conf16 import conf16 +from subf16_model import subf16_model + +def clf16(s, x, u, const, model='stevens', adjust_cy=True): + ''' + objective function of the optimization to find the trim conditions + + x and u get modified in-place + returns the cost + ''' + + _, singam, _, _, tr, _, _, _, thetadot, _, _, orient = const + gamm = asin(singam) + + if len(s) == 3: + u[0] = s[0] + u[1] = s[1] + x[1] = s[2] + else: + u[0] = s[0] + u[1] = s[1] + u[2] = s[2] + u[3] = s[3] + x[1] = s[4] + x[3] = s[5] + x[4] = s[6] + + # + # Get the current power and constraints + # + x[12] = tgear(u[0]) + [x, u] = conf16(x, u, const) + + # we just want the derivative + subf16 = lambda x, u: subf16_model(x, u, model, adjust_cy)[0] + + xd = subf16(x, u) + + # + # Steady Level flight + # + if orient == 1: + r = 100.0*(xd[0]**2 + xd[1]**2 + xd[2]**2 + xd[6]**2 + xd[7]**2 + xd[8]**2) + + # + # Steady Climb + # + if orient == 2: + r = 500.0*(xd[11]-x[0]*sin(gamm))**2 + xd[0]**2 + 100.0*(xd[1]**2 + xd[2]**2) + \ + 10.0*(xd[6]**2 + xd[7]**2 + xd[8]**2) + + + + # + # Coord Turn + # + if orient == 3: + r = xd[0]*xd[0] + 100.0 * (xd[1] * xd[1] + xd[2]*xd[2] + xd[11]*xd[11]) + 10.0*(xd[6]*xd[6] + \ + xd[7]*xd[7]+xd[8]*xd[8]) + 500.0*(xd[5] - tr)**2 + + # + # Pitch Pull Up + # + + + if orient == 4: + r = 500.0*(xd[4]-thetadot)**2 + xd[0]**2 + 100.0*(xd[1]**2 + xd[2]**2) + 10.0*(xd[6]**2 + xd[7]**2 + xd[8]**2) + + # + # Scale r if it is less than 1 + # + if r < 1.0: + r = r**0.5 + + return r diff --git a/demo/F16/aerobench/lowlevel/cm.py b/demo/F16/aerobench/lowlevel/cm.py new file mode 100644 index 0000000000000000000000000000000000000000..4c00a17596797b95076272f079f7e81502396b30 --- /dev/null +++ b/demo/F16/aerobench/lowlevel/cm.py @@ -0,0 +1,50 @@ +''' +Stanley Bak +F-16 GCAS Python +''' + +import numpy as np +from aerobench.util import fix, sign + +def cm(alpha, el): + 'cm function' + + a = np.array([[.205, .168, .186, .196, .213, .251, .245, .238, .252, .231, .198, .192], \ + [.081, .077, .107, .110, .110, .141, .127, .119, .133, .108, .081, .093], \ + [-.046, -.020, -.009, -.005, -.006, .010, .006, -.001, .014, .000, -.013, .032], \ + [-.174, -.145, -.121, -.127, -.129, -.102, -.097, -.113, -.087, -.084, -.069, -.006], \ + [-.259, -.202, -.184, -.193, -.199, -.150, -.160, -.167, -.104, -.076, -.041, -.005]], dtype=float).T + + s = .2 * alpha + k = fix(s) + + if k <= -2: + k = -1 + + if k >= 9: + k = 8 + + da = s - k + l = k + fix(1.1 * sign(da)) + s = el / 12 + m = fix(s) + + if m <= -2: + m = -1 + + if m >= 2: + m = 1 + + de = s - m + n = m + fix(1.1 * sign(de)) + k = k + 3 + l = l + 3 + m = m + 3 + n = n + 3 + t = a[k-1, m-1] + u = a[k-1, n-1] + v = t + abs(da) * (a[l-1, m-1] - t) + w = u + abs(da) * (a[l-1, n-1] - u) + + return v + (w - v) * abs(de) + diff --git a/demo/F16/aerobench/lowlevel/cn.py b/demo/F16/aerobench/lowlevel/cn.py new file mode 100644 index 0000000000000000000000000000000000000000..b06e3729c43572920230f410ef80555c3b766ee8 --- /dev/null +++ b/demo/F16/aerobench/lowlevel/cn.py @@ -0,0 +1,54 @@ +''' +Stanley Bak +F16 GCAS in Python +cn function +''' + +import numpy as np +from aerobench.util import fix, sign + +def cn(alpha, beta): + 'cn function' + + a = np.array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], \ + [.018, .019, .018, .019, .019, .018, .013, .007, .004, -.014, -.017, -.033], \ + [.038, .042, .042, .042, .043, .039, .030, .017, .004, -.035, -.047, -.057], \ + [.056, .057, .059, .058, .058, .053, .032, .012, .002, -.046, -.071, -.073], \ + [.064, .077, .076, .074, .073, .057, .029, .007, .012, -.034, -.065, -.041], \ + [.074, .086, .093, .089, .080, .062, .049, .022, .028, -.012, -.002, -.013], \ + [.079, .090, .106, .106, .096, .080, .068, .030, .064, .015, .011, -.001]], dtype=float).T + + s = .2 * alpha + k = fix(s) + + if k <= -2: + k = -1 + + if k >= 9: + k = 8 + + da = s - k + l = k + fix(1.1 * sign(da)) + s = .2 * abs(beta) + m = fix(s) + + if m == 0: + m = 1 + + if m >= 6: + m = 5 + + db = s - m + n = m + fix(1.1 * sign(db)) + l = l + 3 + k = k + 3 + m = m + 1 + n = n + 1 + t = a[k-1, m-1] + u = a[k-1, n-1] + + v = t + abs(da) * (a[l-1, m-1] - t) + w = u + abs(da) * (a[l-1, n-1] - u) + dum = v + (w - v) * abs(db) + + return dum * sign(beta) diff --git a/demo/F16/aerobench/lowlevel/conf16.py b/demo/F16/aerobench/lowlevel/conf16.py new file mode 100644 index 0000000000000000000000000000000000000000..f322a6def618aa96f0b5163f1c73cd4ca75cd83f --- /dev/null +++ b/demo/F16/aerobench/lowlevel/conf16.py @@ -0,0 +1,58 @@ +''' +Stanley Bak +Python F-16 + +Apply constraints to x variable +used when finding trim conditions +''' + +from math import sin, cos, asin +from tgear import tgear + +def conf16(x, u, const): + 'apply constraints to x' + + radgam, singam, rr, pr, tr, phi, cphi, sphi, thetadot, coord, stab, orient = const + gamm = asin(singam) + + # + # Steady Level Flight + # + if orient == 1: + x[3] = phi # Phi + x[4] = x[1] # Theta + x[6] = rr # Roll Rate + x[7] = pr # Pitch Rate + x[8] = 0.0 # Yaw Rate + + # + # Steady Climb + # + if orient == 2: + x[3] = phi # Phi + x[4] = x[1] + radgam # Theta + x[6] = rr # Roll Rate + x[7] = pr # Pitch Rate + x[8] = 0.0 # Yaw Rate + + # + # orient=3 implies coordinated turn + # + if orient == 3: + x[6] = -tr * sin(x[4]) # Roll Rate + x[7] = tr * cos(x[4]) * sin(x[3]) # Pitch Rate + x[8] = tr * cos(x[4]) * cos(x[3]) # Yaw Rate + + # + # Pitch Pull Up + # + if orient == 4: + x[4] = x[1] # Theta = alpha + x[3] = phi # Phi + x[6] = rr # Roll Rate + x[7] = thetadot # Pitch Rate + x[8] = 0.0 # Yaw Rate + + x[12] = tgear(u[0]) + + return x, u diff --git a/demo/F16/aerobench/lowlevel/cx.py b/demo/F16/aerobench/lowlevel/cx.py new file mode 100644 index 0000000000000000000000000000000000000000..c23a96cb6ab44c14e48dce145e1ff9bb4179e8fd --- /dev/null +++ b/demo/F16/aerobench/lowlevel/cx.py @@ -0,0 +1,50 @@ +''' +Stanley Bak +Python F-16 GCAS +cx +''' + +import numpy as np + +from aerobench.util import fix, sign + +def cx(alpha, el): + 'cx definition' + + a = np.array([[-.099, -.081, -.081, -.063, -.025, .044, .097, .113, .145, .167, .174, .166], \ + [-.048, -.038, -.040, -.021, .016, .083, .127, .137, .162, .177, .179, .167], \ + [-.022, -.020, -.021, -.004, .032, .094, .128, .130, .154, .161, .155, .138], \ + [-.040, -.038, -.039, -.025, .006, .062, .087, .085, .100, .110, .104, .091], \ + [-.083, -.073, -.076, -.072, -.046, .012, .024, .025, .043, .053, .047, .040]], dtype=float).T + + s = .2 * alpha + k = fix(s) + if k <= -2: + k = -1 + + if k >= 9: + k = 8 + + da = s - k + l = k + fix(1.1 * sign(da)) + s = el / 12 + m = fix(s) + if m <= -2: + m = -1 + + if m >= 2: + m = 1 + + de = s - m + n = m + fix(1.1 * sign(de)) + k = k + 3 + l = l + 3 + m = m + 3 + n = n + 3 + t = a[k-1, m-1] + u = a[k-1, n-1] + v = t + abs(da) * (a[l-1, m-1] - t) + w = u + abs(da) * (a[l-1, n-1] - u) + cxx = v + (w - v) * abs(de) + + return cxx diff --git a/demo/F16/aerobench/lowlevel/cy.py b/demo/F16/aerobench/lowlevel/cy.py new file mode 100644 index 0000000000000000000000000000000000000000..fac8c8c34ae7f1ea1307a9332e30774a146e760c --- /dev/null +++ b/demo/F16/aerobench/lowlevel/cy.py @@ -0,0 +1,9 @@ +''' +Stanley Bak +Python F-16 GCAS +''' + +def cy(beta, ail, rdr): + 'cy function' + + return -.02 * beta + .021 * (ail / 20) + .086 * (rdr / 30) diff --git a/demo/F16/aerobench/lowlevel/cz.py b/demo/F16/aerobench/lowlevel/cz.py new file mode 100644 index 0000000000000000000000000000000000000000..66f5769545f3fdcefb1c010dcd74483b9c0a53e5 --- /dev/null +++ b/demo/F16/aerobench/lowlevel/cz.py @@ -0,0 +1,31 @@ +''' +Stanley Bak +Python F-16 GCAS +Cz function +''' + +import numpy as np +from aerobench.util import fix, sign + +def cz(alpha, beta, el): + 'cz function' + + a = np.array([.770, .241, -.100, -.415, -.731, -1.053, -1.355, -1.646, -1.917, -2.120, -2.248, -2.229], \ + dtype=float).T + + s = .2 * alpha + k = fix(s) + + if k <= -2: + k = -1 + + if k >= 9: + k = 8 + + da = s - k + l = k + fix(1.1 * sign(da)) + l = l + 3 + k = k + 3 + s = a[k-1] + abs(da) * (a[l-1] - a[k-1]) + + return s * (1 - (beta / 57.3)**2) - .19 * (el / 25) diff --git a/demo/F16/aerobench/lowlevel/dampp.py b/demo/F16/aerobench/lowlevel/dampp.py new file mode 100644 index 0000000000000000000000000000000000000000..ffca449504156a1f2a1d50f98eb2b207c114d127 --- /dev/null +++ b/demo/F16/aerobench/lowlevel/dampp.py @@ -0,0 +1,42 @@ +''' +Stanley Bak +F16 GCAS in Python +dampp function +''' + +import numpy as np +from aerobench.util import fix, sign + +def dampp(alpha): + 'dampp functon' + + a = np.array([[-.267, -.110, .308, 1.34, 2.08, 2.91, 2.76, 2.05, 1.50, 1.49, 1.83, 1.21], \ + [.882, .852, .876, .958, .962, .974, .819, .483, .590, 1.21, -.493, -1.04], \ + [-.108, -.108, -.188, .110, .258, .226, .344, .362, .611, .529, .298, -2.27], \ + [-8.80, -25.8, -28.9, -31.4, -31.2, -30.7, -27.7, -28.2, -29.0, -29.8, -38.3, -35.3], \ + [-.126, -.026, .063, .113, .208, .230, .319, .437, .680, .100, .447, -.330], \ + [-.360, -.359, -.443, -.420, -.383, -.375, -.329, -.294, -.230, -.210, -.120, -.100], \ + [-7.21, -.540, -5.23, -5.26, -6.11, -6.64, -5.69, -6.00, -6.20, -6.40, -6.60, -6.00], \ + [-.380, -.363, -.378, -.386, -.370, -.453, -.550, -.582, -.595, -.637, -1.02, -.840], \ + [.061, .052, .052, -.012, -.013, -.024, .050, .150, .130, .158, .240, .150]], dtype=float).T + + s = .2 * alpha + k = fix(s) + + if k <= -2: + k = -1 + + if k >= 9: + k = 8 + + da = s - k + l = k + fix(1.1 * sign(da)) + k = k + 3 + l = l + 3 + + d = np.zeros((9,)) + + for i in range(9): + d[i] = a[k-1, i] + abs(da) * (a[l-1, i] - a[k-1, i]) + + return d diff --git a/demo/F16/aerobench/lowlevel/dlda.py b/demo/F16/aerobench/lowlevel/dlda.py new file mode 100644 index 0000000000000000000000000000000000000000..1f6b424b85147182a2c6021978932c9f06d2071f --- /dev/null +++ b/demo/F16/aerobench/lowlevel/dlda.py @@ -0,0 +1,50 @@ +''' +Stanley Bak +F-16 GCAS in Python +dlda function +''' + +import numpy as np +from aerobench.util import fix, sign + +def dlda(alpha, beta): + 'dlda function' + + a = np.array([[-.041, -.052, -.053, -.056, -.050, -.056, -.082, -.059, -.042, -.038, -.027, -.017], \ + [-.041, -.053, -.053, -.053, -.050, -.051, -.066, -.043, -.038, -.027, -.023, -.016], \ + [-.042, -.053, -.052, -.051, -.049, -.049, -.043, -.035, -.026, -.016, -.018, -.014], \ + [-.040, -.052, -.051, -.052, -.048, -.048, -.042, -.037, -.031, -.026, -.017, -.012], \ + [-.043, -.049, -.048, -.049, -.043, -.042, -.042, -.036, -.025, -.021, -.016, -.011], \ + [-.044, -.048, -.048, -.047, -.042, -.041, -.020, -.028, -.013, -.014, -.011, -.010], \ + [-.043, -.049, -.047, -.045, -.042, -.037, -.003, -.013, -.010, -.003, -.007, -.008]], dtype=float).T + + s = .2 * alpha + k = fix(s) + if k <= -2: + k = -1 + + if k >= 9: + k = 8 + + da = s - k + l = k + fix(1.1 * sign(da)) + s = .1 * beta + m = fix(s) + if m <= -3: + m = -2 + + if m >= 3: + m = 2 + + db = s - m + n = m + fix(1.1 * sign(db)) + l = l + 3 + k = k + 3 + m = m + 4 + n = n + 4 + t = a[k-1, m-1] + u = a[k-1, n-1] + v = t + abs(da) * (a[l-1, m-1] - t) + w = u + abs(da) * (a[l-1, n-1] - u) + + return v + (w - v) * abs(db) diff --git a/demo/F16/aerobench/lowlevel/dldr.py b/demo/F16/aerobench/lowlevel/dldr.py new file mode 100644 index 0000000000000000000000000000000000000000..8d75348462f36b83800feb01659e7477df40f327 --- /dev/null +++ b/demo/F16/aerobench/lowlevel/dldr.py @@ -0,0 +1,52 @@ +''' +Stanley Bak +Python GCAS F - 16 +dldr function +''' + +import numpy as np +from aerobench.util import sign, fix + +def dldr(alpha, beta): + 'dldr function' + + a = np.array([[.005, .017, .014, .010, -.005, .009, .019, .005, -.000, -.005, -.011, .008], \ + [.007, .016, .014, .014, .013, .009, .012, .005, .000, .004, .009, .007], \ + [.013, .013, .011, .012, .011, .009, .008, .005, -.002, .005, .003, .005], \ + [.018, .015, .015, .014, .014, .014, .014, .015, .013, .011, .006, .001], \ + [.015, .014, .013, .013, .012, .011, .011, .010, .008, .008, .007, .003], \ + [.021, .011, .010, .011, .010, .009, .008, .010, .006, .005, .000, .001], \ + [.023, .010, .011, .011, .011, .010, .008, .010, .006, .014, .020, .000]], dtype=float).T + + s = .2 * alpha + k = fix(s) + if k <= -2: + k = -1 + + if k >= 9: + k = 8 + + da = s - k + l = k + fix(1.1 * sign(da)) + s = .1 * beta + m = fix(s) + + if m <= -3: + m = -2 + + if m >= 3: + m = 2 + + db = s - m + n = m + fix(1.1 * sign(db)) + l = l + 3 + k = k + 3 + m = m + 4 + n = n + 4 + t = a[k-1, m-1] + u = a[k-1, n-1] + + v = t + abs(da) * (a[l-1, m-1] - t) + w = u + abs(da) * (a[l-1, n-1] - u) + + return v + (w - v) * abs(db) diff --git a/demo/F16/aerobench/lowlevel/dnda.py b/demo/F16/aerobench/lowlevel/dnda.py new file mode 100644 index 0000000000000000000000000000000000000000..d7e623554415d52300ec68953f92aadc14debf78 --- /dev/null +++ b/demo/F16/aerobench/lowlevel/dnda.py @@ -0,0 +1,51 @@ +''' +Stanley Bak +F16 GCAS in Python +dnda function +''' + +import numpy as np +from aerobench.util import fix, sign + +def dnda(alpha, beta): + 'dnda function' + + a = np.array([[.001, -.027, -.017, -.013, -.012, -.016, .001, .017, .011, .017, .008, .016], \ + [.002, -.014, -.016, -.016, -.014, -.019, -.021, .002, .012, .016, .015, .011], \ + [-.006, -.008, -.006, -.006, -.005, -.008, -.005, .007, .004, .007, .006, .006], \ + [-.011, -.011, -.010, -.009, -.008, -.006, .000, .004, .007, .010, .004, .010], \ + [-.015, -.015, -.014, -.012, -.011, -.008, -.002, .002, .006, .012, .011, .011], \ + [-.024, -.010, -.004, -.002, -.001, .003, .014, .006, -.001, .004, .004, .006], \ + [-.022, .002, -.003, -.005, -.003, -.001, -.009, -.009, -.001, .003, -.002, .001]], dtype=float).T + + s = .2 * alpha + k = fix(s) + + if k <= -2: + k = -1 + + if k >= 9: + k = 8 + + da = s - k + l = k + fix(1.1 * sign(da)) + s = .1 * beta + m = fix(s) + if m <= -3: + m = -2 + + if m >= 3: + m = 2 + + db = s - m + n = m + fix(1.1 * sign(db)) + l = l + 3 + k = k + 3 + m = m + 4 + n = n + 4 + t = a[k-1, m-1] + u = a[k-1, n-1] + v = t + abs(da) * (a[l-1, m-1] - t) + w = u + abs(da) * (a[l-1, n-1] - u) + + return v + (w - v) * abs(db) diff --git a/demo/F16/aerobench/lowlevel/dndr.py b/demo/F16/aerobench/lowlevel/dndr.py new file mode 100644 index 0000000000000000000000000000000000000000..805f170958e9a55df40addd68e2ab00566de0538 --- /dev/null +++ b/demo/F16/aerobench/lowlevel/dndr.py @@ -0,0 +1,49 @@ +''' +Stanley Bak +F16 GCAS in Python +dndr function +''' + +import numpy as np +from aerobench.util import fix, sign + +def dndr(alpha, beta): + 'dndr function' + + a = np.array([[-.018, -.052, -.052, -.052, -.054, -.049, -.059, -.051, -.030, -.037, -.026, -.013], \ + [-.028, -.051, -.043, -.046, -.045, -.049, -.057, -.052, -.030, -.033, -.030, -.008], \ + [-.037, -.041, -.038, -.040, -.040, -.038, -.037, -.030, -.027, -.024, -.019, -.013], \ + [-.048, -.045, -.045, -.045, -.044, -.045, -.047, -.048, -.049, -.045, -.033, -.016], \ + [-.043, -.044, -.041, -.041, -.040, -.038, -.034, -.035, -.035, -.029, -.022, -.009], \ + [-.052, -.034, -.036, -.036, -.035, -.028, -.024, -.023, -.020, -.016, -.010, -.014], \ + [-.062, -.034, -.027, -.028, -.027, -.027, -.023, -.023, -.019, -.009, -.025, -.010]], dtype=float).T + + s = .2 * alpha + k = fix(s) + if k <= -2: + k = -1 + + if k >= 9: + k = 8 + + da = s - k + l = k + fix(1.1 * sign(da)) + s = .1 * beta + m = fix(s) + if m <= -3: + m = -2 + + if m >= 3: + m = 2 + + db = s - m + n = m + fix(1.1 * sign(db)) + l = l + 3 + k = k + 3 + m = m + 4 + n = n + 4 + t = a[k-1, m-1] + u = a[k-1, n-1] + v = t + abs(da) * (a[l-1, m-1] - t) + w = u + abs(da) * (a[l-1, n-1] - u) + return v + (w - v) * abs(db) diff --git a/demo/F16/aerobench/lowlevel/low_level_controller.py b/demo/F16/aerobench/lowlevel/low_level_controller.py new file mode 100644 index 0000000000000000000000000000000000000000..d472bf309937035d4cef72ba5c7fcf0de6c94400 --- /dev/null +++ b/demo/F16/aerobench/lowlevel/low_level_controller.py @@ -0,0 +1,109 @@ +''' +Stanley Bak +Low-level flight controller +''' + +import numpy as np +from aerobench.util import Freezable + +class CtrlLimits(Freezable): + 'Control Limits' + + def __init__(self): + self.ThrottleMax = 1 # Afterburner on for throttle > 0.7 + self.ThrottleMin = 0 + self.ElevatorMaxDeg = 25 + self.ElevatorMinDeg = -25 + self.AileronMaxDeg = 21.5 + self.AileronMinDeg = -21.5 + self.RudderMaxDeg = 30 + self.RudderMinDeg = -30 + + self.NzMax = 6 + self.NzMin = -1 + + self.freeze_attrs() + +class LowLevelController(Freezable): + '''low level flight controller + ''' + + old_k_long = np.array([[-156.8801506723475, -31.037008068526642, -38.72983346216317]], dtype=float) + old_k_lat = np.array([[37.84483, -25.40956, -6.82876, -332.88343, -17.15997], + [-23.91233, 5.69968, -21.63431, 64.49490, -88.36203]], dtype=float) + + old_xequil = np.array([502.0, 0.0389, 0.0, 0.0, 0.0389, 0.0, 0.0, 0.0, \ + 0.0, 0.0, 0.0, 1000.0, 9.0567], dtype=float).transpose() + old_uequil = np.array([0.1395, -0.7496, 0.0, 0.0], dtype=float).transpose() + + def __init__(self, gain_str='old'): + # Hard coded LQR gain matrix from matlab version + + assert gain_str == 'old' + + # Longitudinal Gains + K_long = LowLevelController.old_k_long + K_lat = LowLevelController.old_k_lat + + self.K_lqr = np.zeros((3, 8)) + self.K_lqr[:1, :3] = K_long + self.K_lqr[1:, 3:] = K_lat + + # equilibrium points from BuildLqrControllers.py + self.xequil = LowLevelController.old_xequil + self.uequil = LowLevelController.old_uequil + + self.ctrlLimits = CtrlLimits() + + self.freeze_attrs() + + def get_u_deg(self, u_ref, f16_state): + 'get the reference commands for the control surfaces' + + # Calculate perturbation from trim state + x_delta = f16_state.copy() + x_delta[:len(self.xequil)] -= self.xequil + + ## Implement LQR Feedback Control + # Reorder states to match controller: + # [alpha, q, int_e_Nz, beta, p, r, int_e_ps, int_e_Ny_r] + x_ctrl = np.array([x_delta[i] for i in [1, 7, 13, 2, 6, 8, 14, 15]], dtype=float) + + # Initialize control vectors + u_deg = np.zeros((4,)) # throt, ele, ail, rud + + # Calculate control using LQR gains + u_deg[1:4] = np.dot(-self.K_lqr, x_ctrl) # Full Control + + # Set throttle as directed from output of getOuterLoopCtrl(...) + u_deg[0] = u_ref[3] + + # Add in equilibrium control + u_deg[0:4] += self.uequil + + ## Limit controls to saturation limits + ctrlLimits = self.ctrlLimits + + # Limit throttle from 0 to 1 + u_deg[0] = max(min(u_deg[0], ctrlLimits.ThrottleMax), ctrlLimits.ThrottleMin) + + # Limit elevator from -25 to 25 deg + u_deg[1] = max(min(u_deg[1], ctrlLimits.ElevatorMaxDeg), ctrlLimits.ElevatorMinDeg) + + # Limit aileron from -21.5 to 21.5 deg + u_deg[2] = max(min(u_deg[2], ctrlLimits.AileronMaxDeg), ctrlLimits.AileronMinDeg) + + # Limit rudder from -30 to 30 deg + u_deg[3] = max(min(u_deg[3], ctrlLimits.RudderMaxDeg), ctrlLimits.RudderMinDeg) + + return x_ctrl, u_deg + + def get_num_integrators(self): + 'get the number of integrators in the low-level controller' + + return 3 + + def get_integrator_derivatives(self, t, x_f16, u_ref, Nz, ps, Ny_r): + 'get the derivatives of the integrators in the low-level controller' + + return [Nz - u_ref[0], ps - u_ref[1], Ny_r - u_ref[2]] diff --git a/demo/F16/aerobench/lowlevel/morellif16.py b/demo/F16/aerobench/lowlevel/morellif16.py new file mode 100644 index 0000000000000000000000000000000000000000..4632b78cdcd91e06777a1f4680db9a44f8623770 --- /dev/null +++ b/demo/F16/aerobench/lowlevel/morellif16.py @@ -0,0 +1,189 @@ +''' +Stanley Bak +F16 GCAS in Python + +Morelli dynamics (Polynomial interpolation) +''' + +def Morellif16(alpha, beta, de, da, dr, p, q, r, cbar, b, V, xcg, xcgref): + 'desc' + + #alpha=max(-10*pi/180,min(45*pi/180,alpha)) # bounds alpha between -10 deg and 45 deg + #beta = max( - 30 * pi / 180, min(30 * pi / 180, beta)) #bounds beta between -30 deg and 30 deg + #de = max( - 25 * pi / 180, min(25 * pi / 180, de)) #bounds elevator deflection between -25 deg and 25 deg + #da = max( - 21.5 * pi / 180, min(21.5 * pi / 180, da)) #bounds aileron deflection between -21.5 deg and 21.5 deg + #dr = max( - 30 * pi / 180, min(30 * pi / 180, dr)) #bounds rudder deflection between -30 deg and 30 deg + + # xcgref = 0.35 + #reference longitudinal cg position in Morelli f16 model + + + phat = p * b / (2 * V) + qhat = q * cbar / (2 * V) + rhat = r * b / (2 * V) + ## + a0 = -1.943367e-2 + a1 = 2.136104e-1 + a2 = -2.903457e-1 + a3 = -3.348641e-3 + a4 = -2.060504e-1 + a5 = 6.988016e-1 + a6 = -9.035381e-1 + + b0 = 4.833383e-1 + b1 = 8.644627 + b2 = 1.131098e1 + b3 = -7.422961e1 + b4 = 6.075776e1 + + c0 = -1.145916 + c1 = 6.016057e-2 + c2 = 1.642479e-1 + + d0 = -1.006733e-1 + d1 = 8.679799e-1 + d2 = 4.260586 + d3 = -6.923267 + + e0 = 8.071648e-1 + e1 = 1.189633e-1 + e2 = 4.177702 + e3 = -9.162236 + + f0 = -1.378278e-1 + f1 = -4.211369 + f2 = 4.775187 + f3 = -1.026225e1 + f4 = 8.399763 + f5 = -4.354000e-1 + + g0 = -3.054956e1 + g1 = -4.132305e1 + g2 = 3.292788e2 + g3 = -6.848038e2 + g4 = 4.080244e2 + + h0 = -1.05853e-1 + h1 = -5.776677e-1 + h2 = -1.672435e-2 + h3 = 1.357256e-1 + h4 = 2.172952e-1 + h5 = 3.464156 + h6 = -2.835451 + h7 = -1.098104 + + i0 = -4.126806e-1 + i1 = -1.189974e-1 + i2 = 1.247721 + i3 = -7.391132e-1 + + j0 = 6.250437e-2 + j1 = 6.067723e-1 + j2 = -1.101964 + j3 = 9.100087 + j4 = -1.192672e1 + + k0 = -1.463144e-1 + k1 = -4.07391e-2 + k2 = 3.253159e-2 + k3 = 4.851209e-1 + k4 = 2.978850e-1 + k5 = -3.746393e-1 + k6 = -3.213068e-1 + + l0 = 2.635729e-2 + l1 = -2.192910e-2 + l2 = -3.152901e-3 + l3 = -5.817803e-2 + l4 = 4.516159e-1 + l5 = -4.928702e-1 + l6 = -1.579864e-2 + + m0 = -2.029370e-2 + m1 = 4.660702e-2 + m2 = -6.012308e-1 + m3 = -8.062977e-2 + m4 = 8.320429e-2 + m5 = 5.018538e-1 + m6 = 6.378864e-1 + m7 = 4.226356e-1 + + n0 = -5.19153 + n1 = -3.554716 + n2 = -3.598636e1 + n3 = 2.247355e2 + n4 = -4.120991e2 + n5 = 2.411750e2 + + o0 = 2.993363e-1 + o1 = 6.594004e-2 + o2 = -2.003125e-1 + o3 = -6.233977e-2 + o4 = -2.107885 + o5 = 2.141420 + o6 = 8.476901e-1 + + p0 = 2.677652e-2 + p1 = -3.298246e-1 + p2 = 1.926178e-1 + p3 = 4.013325 + p4 = -4.404302 + + q0 = -3.698756e-1 + q1 = -1.167551e-1 + q2 = -7.641297e-1 + + r0 = -3.348717e-2 + r1 = 4.276655e-2 + r2 = 6.573646e-3 + r3 = 3.535831e-1 + r4 = -1.373308 + r5 = 1.237582 + r6 = 2.302543e-1 + r7 = -2.512876e-1 + r8 = 1.588105e-1 + r9 = -5.199526e-1 + + s0 = -8.115894e-2 + s1 = -1.156580e-2 + s2 = 2.514167e-2 + s3 = 2.038748e-1 + s4 = -3.337476e-1 + s5 = 1.004297e-1 + + ## + Cx0 = a0 + a1 * alpha + a2 * de**2 + a3 * de + a4 * alpha * de + a5 * alpha**2 + a6 * alpha**3 + Cxq = b0 + b1 * alpha + b2 * alpha**2 + b3 * alpha**3 + b4 * alpha**4 + Cy0 = c0 * beta + c1 * da + c2 * dr + Cyp = d0 + d1 * alpha + d2 * alpha**2 + d3 * alpha**3 + Cyr = e0 + e1 * alpha + e2 * alpha**2 + e3 * alpha**3 + Cz0 = (f0 + f1 * alpha + f2 * alpha**2 + f3 * alpha**3 + f4 * alpha**4) * (1 - beta**2) + f5 * de + Czq = g0 + g1 * alpha + g2 * alpha**2 + g3 * alpha**3 + g4 * alpha**4 + Cl0 = h0 * beta + h1 * alpha * beta + h2 * alpha**2 * beta + h3 * beta**2 + h4 * alpha * beta**2 + h5 * \ + alpha**3 * beta + h6 * alpha**4 * beta + h7 * alpha**2 * beta**2 + Clp = i0 + i1 * alpha + i2 * alpha**2 + i3 * alpha**3 + Clr = j0 + j1 * alpha + j2 * alpha**2 + j3 * alpha**3 + j4 * alpha**4 + Clda = k0 + k1 * alpha + k2 * beta + k3 * alpha**2 + k4 * alpha * beta + k5 * alpha**2 * beta + k6 * alpha**3 + Cldr = l0 + l1 * alpha + l2 * beta + l3 * alpha * beta + l4 * alpha**2 * beta + l5 * alpha**3 * beta + l6 * beta**2 + Cm0 = m0 + m1 * alpha + m2 * de + m3 * alpha * de + m4 * de**2 + m5 * alpha**2 * de + m6 * de**3 + m7 * \ + alpha * de**2 + + + Cmq = n0 + n1 * alpha + n2 * alpha**2 + n3 * alpha**3 + n4 * alpha**4 + n5 * alpha**5 + Cn0 = o0 * beta + o1 * alpha * beta + o2 * beta**2 + o3 * alpha * beta**2 + o4 * alpha**2 * beta + o5 * \ + alpha**2 * beta**2 + o6 * alpha**3 * beta + Cnp = p0 + p1 * alpha + p2 * alpha**2 + p3 * alpha**3 + p4 * alpha**4 + Cnr = q0 + q1 * alpha + q2 * alpha**2 + Cnda = r0 + r1 * alpha + r2 * beta + r3 * alpha * beta + r4 * alpha**2 * beta + r5 * alpha**3 * beta + r6 * \ + alpha**2 + r7 * alpha**3 + r8 * beta**3 + r9 * alpha * beta**3 + Cndr = s0 + s1 * alpha + s2 * beta + s3 * alpha * beta + s4 * alpha**2 * beta + s5 * alpha**2 + ## + + Cx = Cx0 + Cxq * qhat + Cy = Cy0 + Cyp * phat + Cyr * rhat + Cz = Cz0 + Czq * qhat + Cl = Cl0 + Clp * phat + Clr * rhat + Clda * da + Cldr * dr + Cm = Cm0 + Cmq * qhat + Cz * (xcgref - xcg) + Cn = Cn0 + Cnp * phat + Cnr * rhat + Cnda * da + Cndr * dr - Cy * (xcgref - xcg) * (cbar / b) + + return Cx, Cy, Cz, Cl, Cm, Cn diff --git a/demo/F16/aerobench/lowlevel/pdot.py b/demo/F16/aerobench/lowlevel/pdot.py new file mode 100644 index 0000000000000000000000000000000000000000..dbfb1c3200aba86e69fcf4a34090768df018a323 --- /dev/null +++ b/demo/F16/aerobench/lowlevel/pdot.py @@ -0,0 +1,29 @@ +''' +Stanley Bak +Python F-16 +power derivative (pdot) +''' + +from aerobench.lowlevel.rtau import rtau + +def pdot(p3, p1): + 'pdot function' + + if p1 >= 50: + if p3 >= 50: + t = 5 + p2 = p1 + else: + p2 = 60 + t = rtau(p2 - p3) + else: + if p3 >= 50: + t = 5 + p2 = 40 + else: + p2 = p1 + t = rtau(p2 - p3) + + pd = t * (p2 - p3) + + return pd diff --git a/demo/F16/aerobench/lowlevel/rtau.py b/demo/F16/aerobench/lowlevel/rtau.py new file mode 100644 index 0000000000000000000000000000000000000000..a2f00d6e32fc1182722af5c9a8c025b5e88138db --- /dev/null +++ b/demo/F16/aerobench/lowlevel/rtau.py @@ -0,0 +1,18 @@ +''' +Stanley Bak +Python F-16 + +Rtau function +''' + +def rtau(dp): + 'rtau function' + + if dp <= 25: + rt = 1.0 + elif dp >= 50: + rt = .1 + else: + rt = 1.9 - .036 * dp + + return rt diff --git a/demo/F16/aerobench/lowlevel/subf16_model.py b/demo/F16/aerobench/lowlevel/subf16_model.py new file mode 100644 index 0000000000000000000000000000000000000000..b807b16520a8b735bef9bc17890c9c0219b4c69d --- /dev/null +++ b/demo/F16/aerobench/lowlevel/subf16_model.py @@ -0,0 +1,201 @@ +''' +Stanley Bak +Python F-16 subf16 +outputs aircraft state vector deriative +''' + +# x[0] = air speed, VT (ft/sec) +# x[1] = angle of attack, alpha (rad) +# x[2] = angle of sideslip, beta (rad) +# x[3] = roll angle, phi (rad) +# x[4] = pitch angle, theta (rad) +# x[5] = yaw angle, psi (rad) +# x[6] = roll rate, P (rad/sec) +# x[7] = pitch rate, Q (rad/sec) +# x[8] = yaw rate, R (rad/sec) +# x[9] = northward horizontal displacement, pn (feet) +# x[10] = eastward horizontal displacement, pe (feet) +# x[11] = altitude, h (feet) +# x[12] = engine thrust dynamics lag state, pow +# +# u[0] = throttle command 0.0 < u(1) < 1.0 +# u[1] = elevator command in degrees +# u[2] = aileron command in degrees +# u[3] = rudder command in degrees +# + +from math import sin, cos, pi + +from aerobench.lowlevel.adc import adc +from aerobench.lowlevel.tgear import tgear +from aerobench.lowlevel.pdot import pdot +from aerobench.lowlevel.thrust import thrust +from aerobench.lowlevel.cx import cx +from aerobench.lowlevel.cy import cy +from aerobench.lowlevel.cz import cz +from aerobench.lowlevel.cl import cl +from aerobench.lowlevel.dlda import dlda +from aerobench.lowlevel.dldr import dldr +from aerobench.lowlevel.cm import cm +from aerobench.lowlevel.cn import cn +from aerobench.lowlevel.dnda import dnda +from aerobench.lowlevel.dndr import dndr +from aerobench.lowlevel.dampp import dampp + +from aerobench.lowlevel.morellif16 import Morellif16 + +def subf16_model(x, u, model, adjust_cy=True): + '''output aircraft state vector derivative for a given input + + The reference for the model is Appendix A of Stevens & Lewis + ''' + + assert model in ['stevens', 'morelli'] + assert len(x) == 13 + assert len(u) == 4 + + xcg = 0.35 + + thtlc, el, ail, rdr = u + + s = 300 + b = 30 + cbar = 11.32 + rm = 1.57e-3 + xcgr = .35 + he = 160.0 + c1 = -.770 + c2 = .02755 + c3 = 1.055e-4 + c4 = 1.642e-6 + c5 = .9604 + c6 = 1.759e-2 + c7 = 1.792e-5 + c8 = -.7336 + c9 = 1.587e-5 + rtod = 57.29578 + g = 32.17 + + xd = x.copy() + vt = x[0] + alpha = x[1]*rtod + beta = x[2]*rtod + phi = x[3] + theta = x[4] + psi = x[5] + p = x[6] + q = x[7] + r = x[8] + alt = x[11] + power = x[12] + + # air data computer and engine model + amach, qbar = adc(vt, alt) + cpow = tgear(thtlc) + + xd[12] = pdot(power, cpow) + + t = thrust(power, alt, amach) + dail = ail/20 + drdr = rdr/30 + + # component build up + + if model == 'stevens': + # stevens & lewis (look up table version) + cxt = cx(alpha, el) + cyt = cy(beta, ail, rdr) + czt = cz(alpha, beta, el) + + clt = cl(alpha, beta) + dlda(alpha, beta) * dail + dldr(alpha, beta) * drdr + cmt = cm(alpha, el) + cnt = cn(alpha, beta) + dnda(alpha, beta) * dail + dndr(alpha, beta) * drdr + else: + # morelli model (polynomial version) + cxt, cyt, czt, clt, cmt, cnt = Morellif16(alpha*pi/180, beta*pi/180, el*pi/180, ail*pi/180, rdr*pi/180, \ + p, q, r, cbar, b, vt, xcg, xcgr) + + # add damping derivatives + + tvt = .5 / vt + b2v = b * tvt + cq = cbar * q * tvt + + # get ready for state equations + d = dampp(alpha) + cxt = cxt + cq * d[0] + cyt = cyt + b2v * (d[1] * r + d[2] * p) + czt = czt + cq * d[3] + clt = clt + b2v * (d[4] * r + d[5] * p) + cmt = cmt + cq * d[6] + czt * (xcgr-xcg) + cnt = cnt + b2v * (d[7] * r + d[8] * p)-cyt * (xcgr-xcg) * cbar/b + cbta = cos(x[2]) + u = vt * cos(x[1]) * cbta + v = vt * sin(x[2]) + w = vt * sin(x[1]) * cbta + sth = sin(theta) + cth = cos(theta) + sph = sin(phi) + cph = cos(phi) + spsi = sin(psi) + cpsi = cos(psi) + qs = qbar * s + qsb = qs * b + rmqs = rm * qs + gcth = g * cth + qsph = q * sph + ay = rmqs * cyt + az = rmqs * czt + + # force equations + udot = r * v-q * w-g * sth + rm * (qs * cxt + t) + vdot = p * w-r * u + gcth * sph + ay + wdot = q * u-p * v + gcth * cph + az + dum = (u * u + w * w) + + xd[0] = (u * udot + v * vdot + w * wdot)/vt + xd[1] = (u * wdot-w * udot)/dum + xd[2] = (vt * vdot-v * xd[0]) * cbta/dum + + # kinematics + xd[3] = p + (sth/cth) * (qsph + r * cph) + xd[4] = q * cph-r * sph + xd[5] = (qsph + r * cph)/cth + + # moments + xd[6] = (c2 * p + c1 * r + c4 * he) * q + qsb * (c3 * clt + c4 * cnt) + + xd[7] = (c5 * p-c7 * he) * r + c6 * (r * r-p * p) + qs * cbar * c7 * cmt + xd[8] = (c8 * p-c2 * r + c9 * he) * q + qsb * (c4 * clt + c9 * cnt) + + # navigation + t1 = sph * cpsi + t2 = cph * sth + t3 = sph * spsi + s1 = cth * cpsi + s2 = cth * spsi + s3 = t1 * sth-cph * spsi + s4 = t3 * sth + cph * cpsi + s5 = sph * cth + s6 = t2 * cpsi + t3 + s7 = t2 * spsi-t1 + s8 = cph * cth + xd[9] = u * s1 + v * s3 + w * s6 # north speed + xd[10] = u * s2 + v * s4 + w * s7 # east speed + xd[11] = u * sth-v * s5-w * s8 # vertical speed + + # outputs + + xa = 15.0 # sets distance normal accel is in front of the c.g. (xa = 15.0 at pilot) + az = az-xa * xd[7] # moves normal accel in front of c.g. + + #################################### + ###### peter additions below ###### + if adjust_cy: + ay = ay+xa*xd[8] # moves side accel in front of c.g. + + # For extraction of Nz + Nz = (-az / g) - 1 # zeroed at 1 g, positive g = pulling up + Ny = ay / g + + return xd, Nz, Ny, az, ay diff --git a/demo/F16/aerobench/lowlevel/tgear.py b/demo/F16/aerobench/lowlevel/tgear.py new file mode 100644 index 0000000000000000000000000000000000000000..8c0fe4735b581a4e947099e12c2a671f68b7ab59 --- /dev/null +++ b/demo/F16/aerobench/lowlevel/tgear.py @@ -0,0 +1,14 @@ +''' +Stanley Bak +Python F-16 GCAS +''' + +def tgear(thtl): + 'tgear function' + + if thtl <= .77: + tg = 64.94 * thtl + else: + tg = 217.38 * thtl - 117.38 + + return tg diff --git a/demo/F16/aerobench/lowlevel/thrust.py b/demo/F16/aerobench/lowlevel/thrust.py new file mode 100644 index 0000000000000000000000000000000000000000..3de5f8b18481376a24d7031c2b1c96af08e75e91 --- /dev/null +++ b/demo/F16/aerobench/lowlevel/thrust.py @@ -0,0 +1,76 @@ +''' +Stanle Bak +Python F-16 +Thrust function +''' + +import numpy as np + +from aerobench.util import fix + +def thrust(power, alt, rmach): + 'thrust lookup-table version' + + a = np.array([[1060, 670, 880, 1140, 1500, 1860], \ + [635, 425, 690, 1010, 1330, 1700], \ + [60, 25, 345, 755, 1130, 1525], \ + [-1020, -170, -300, 350, 910, 1360], \ + [-2700, -1900, -1300, -247, 600, 1100], \ + [-3600, -1400, -595, -342, -200, 700]], dtype=float).T + + b = np.array([[12680, 9150, 6200, 3950, 2450, 1400], \ + [12680, 9150, 6313, 4040, 2470, 1400], \ + [12610, 9312, 6610, 4290, 2600, 1560], \ + [12640, 9839, 7090, 4660, 2840, 1660], \ + [12390, 10176, 7750, 5320, 3250, 1930], \ + [11680, 9848, 8050, 6100, 3800, 2310]], dtype=float).T + + c = np.array([[20000, 15000, 10800, 7000, 4000, 2500], \ + [21420, 15700, 11225, 7323, 4435, 2600], \ + [22700, 16860, 12250, 8154, 5000, 2835], \ + [24240, 18910, 13760, 9285, 5700, 3215], \ + [26070, 21075, 15975, 11115, 6860, 3950], \ + [28886, 23319, 18300, 13484, 8642, 5057]], dtype=float).T + + if alt < 0: + alt = 0.01 # uh, why not 0? + + h = .0001 * alt + + i = fix(h) + + if i >= 5: + i = 4 + + dh = h - i + rm = 5 * rmach + m = fix(rm) + + if m >= 5: + m = 4 + elif m <= 0: + m = 0 + + dm = rm - m + cdh = 1 - dh + + # do not increment these, since python is 0-indexed while matlab is 1-indexed + #i = i + 1 + #m = m + 1 + + s = b[i, m] * cdh + b[i + 1, m] * dh + t = b[i, m + 1] * cdh + b[i + 1, m + 1] * dh + tmil = s + (t - s) * dm + + if power < 50: + s = a[i, m] * cdh + a[i + 1, m] * dh + t = a[i, m + 1] * cdh + a[i + 1, m + 1] * dh + tidl = s + (t - s) * dm + thrst = tidl + (tmil - tidl) * power * .02 + else: + s = c[i, m] * cdh + c[i + 1, m] * dh + t = c[i, m + 1] * cdh + c[i + 1, m + 1] * dh + tmax = s + (t - s) * dm + thrst = tmil + (tmax - tmil) * (power - 50) * .02 + + return thrst diff --git a/demo/F16/aerobench/run_f16_sim.py b/demo/F16/aerobench/run_f16_sim.py new file mode 100644 index 0000000000000000000000000000000000000000..59f5633b9a9a4856b2566584ae264b020c758ad3 --- /dev/null +++ b/demo/F16/aerobench/run_f16_sim.py @@ -0,0 +1,225 @@ +''' +Stanley Bak +run_f16_sim python version +''' + +import time + +import numpy as np +from scipy.integrate import RK45 + +from aerobench.highlevel.controlled_f16 import controlled_f16 +from aerobench.util import get_state_names, Euler + +from dryvr_plus_plus import BaseAgent + + +class F16Agent(BaseAgent): + '''Dynamics of an F16 aircraft + derived from Stanley Bak's python library''' + def __init__(self, id, code = None, file_name = None): + '''Contructor for one F16 agent + EXACTLY one of the following should be given + file_name: name of the controller + code: pyhton string ddefning the controller + ''' + # Calling the constructor of tha base class + super().__init__(id, code, file_name) + +def run_f16_sim(initial_state, tmax, ap, step=1/30, extended_states=False, model_str='morelli', + integrator_str='rk45', v2_integrators=False): + '''Simulates and analyzes autonomous F-16 maneuvers + + if multiple aircraft are to be simulated at the same time, + initial_state should be the concatenated full (including integrators) initial state. + + returns a dict with the following keys: + + 'status': integration status, should be 'finished' if no errors, or 'autopilot finished' + 'times': time history + 'states': state history at each time step + 'modes': mode history at each time step + + if extended_states was True, result also includes: + 'xd_list' - derivative at each time step + 'ps_list' - ps at each time step + 'Nz_list' - Nz at each time step + 'Ny_r_list' - Ny_r at each time step + 'u_list' - input at each time step, input is 7-tuple: throt, ele, ail, rud, Nz_ref, ps_ref, Ny_r_ref + These are tuples if multiple aircraft are used + ''' + + start = time.perf_counter() + + initial_state = np.array(initial_state, dtype=float) + llc = ap.llc + + num_vars = len(get_state_names()) + llc.get_num_integrators() + + if initial_state.size < num_vars: + # append integral error states to state vector + x0 = np.zeros(num_vars) + x0[:initial_state.shape[0]] = initial_state + else: + x0 = initial_state + + assert x0.size % num_vars == 0, f"expected initial state ({x0.size} vars) to be multiple of {num_vars} vars" + + # run the numerical simulation + times = [0] + states = [x0] + + # mode can change at time 0 + ap.advance_discrete_mode(times[-1], states[-1]) + + modes = [ap.mode] + + if extended_states: + xd, u, Nz, ps, Ny_r = get_extended_states(ap, times[-1], states[-1], model_str, v2_integrators) + + xd_list = [xd] + u_list = [u] + Nz_list = [Nz] + ps_list = [ps] + Ny_r_list = [Ny_r] + + der_func = make_der_func(ap, model_str, v2_integrators) + + if integrator_str == 'rk45': + integrator_class = RK45 + kwargs = {} + else: + assert integrator_str == 'euler' + integrator_class = Euler + kwargs = {'step': step} + + # note: fixed_step argument is unused by rk45, used with euler + integrator = integrator_class(der_func, times[-1], states[-1], tmax, **kwargs) + + while integrator.status == 'running': + integrator.step() + + if integrator.t >= times[-1] + step: + dense_output = integrator.dense_output() + + while integrator.t >= times[-1] + step: + t = times[-1] + step + #print(f"{round(t, 2)} / {tmax}") + + times.append(t) + states.append(dense_output(t)) + + updated = ap.advance_discrete_mode(times[-1], states[-1]) + modes.append(ap.mode) + + # re-run dynamics function at current state to get non-state variables + if extended_states: + xd, u, Nz, ps, Ny_r = get_extended_states(ap, times[-1], states[-1], model_str, v2_integrators) + + xd_list.append(xd) + u_list.append(u) + + Nz_list.append(Nz) + ps_list.append(ps) + Ny_r_list.append(Ny_r) + + if ap.is_finished(times[-1], states[-1]): + # this both causes the outer loop to exit and sets res['status'] appropriately + integrator.status = 'autopilot finished' + break + + if updated: + # re-initialize the integration class on discrete mode switches + integrator = integrator_class(der_func, times[-1], states[-1], tmax, **kwargs) + break + + assert 'finished' in integrator.status + + res = {} + res['status'] = integrator.status + res['times'] = times + res['states'] = np.array(states, dtype=float) + res['modes'] = modes + + if extended_states: + res['xd_list'] = xd_list + res['ps_list'] = ps_list + res['Nz_list'] = Nz_list + res['Ny_r_list'] = Ny_r_list + res['u_list'] = u_list + + res['runtime'] = time.perf_counter() - start + + return res + +def make_der_func(ap, model_str, v2_integrators): + 'make the combined derivative function for integration' + + def der_func(t, full_state): + 'derivative function, generalized for multiple aircraft' + + u_refs = ap.get_checked_u_ref(t, full_state) + + num_aircraft = u_refs.size // 4 + num_vars = len(get_state_names()) + ap.llc.get_num_integrators() + assert full_state.size // num_vars == num_aircraft + + xds = [] + + for i in range(num_aircraft): + state = full_state[num_vars*i:num_vars*(i+1)] + u_ref = u_refs[4*i:4*(i+1)] + + xd = controlled_f16(t, state, u_ref, ap.llc, model_str, v2_integrators)[0] + xds.append(xd) + + rv = np.hstack(xds) + + return rv + + return der_func + +def get_extended_states(ap, t, full_state, model_str, v2_integrators): + '''get xd, u, Nz, ps, Ny_r at the current time / state + + returns tuples if more than one aircraft + ''' + + llc = ap.llc + num_vars = len(get_state_names()) + llc.get_num_integrators() + num_aircraft = full_state.size // num_vars + + xd_tup = [] + u_tup = [] + Nz_tup = [] + ps_tup = [] + Ny_r_tup = [] + + u_refs = ap.get_checked_u_ref(t, full_state) + + for i in range(num_aircraft): + state = full_state[num_vars*i:num_vars*(i+1)] + u_ref = u_refs[4*i:4*(i+1)] + + xd, u, Nz, ps, Ny_r = controlled_f16(t, state, u_ref, llc, model_str, v2_integrators) + + xd_tup.append(xd) + u_tup.append(u) + Nz_tup.append(Nz) + ps_tup.append(ps) + Ny_r_tup.append(Ny_r) + + if num_aircraft == 1: + rv_xd = xd_tup[0] + rv_u = u_tup[0] + rv_Nz = Nz_tup[0] + rv_ps = ps_tup[0] + rv_Ny_r = Ny_r_tup[0] + else: + rv_xd = tuple(xd_tup) + rv_u = tuple(u_tup) + rv_Nz = tuple(Nz_tup) + rv_ps = tuple(ps_tup) + rv_Ny_r = tuple(Ny_r_tup) + + return rv_xd, rv_u, rv_Nz, rv_ps, rv_Ny_r diff --git a/demo/F16/aerobench/util.py b/demo/F16/aerobench/util.py new file mode 100644 index 0000000000000000000000000000000000000000..cf2c4796629e0f3416452ab99ef538d9c973ce57 --- /dev/null +++ b/demo/F16/aerobench/util.py @@ -0,0 +1,287 @@ +''' +Utilities for F-16 GCAS +''' + +from math import floor, ceil +import numpy as np + +class StateIndex: + 'list of static state indices' + + VT = 0 + VEL = 0 # alias + + ALPHA = 1 + BETA = 2 + PHI = 3 # roll angle + THETA = 4 # pitch angle + PSI = 5 # yaw angle + + P = 6 + Q = 7 + R = 8 + + POSN = 9 + POS_N = 9 + + POSE = 10 + POS_E = 10 + + ALT = 11 + H = 11 + + POW = 12 + +class Freezable(): + 'a class where you can freeze the fields (prevent new fields from being created)' + + _frozen = False + + def freeze_attrs(self): + 'prevents any new attributes from being created in the object' + self._frozen = True + + def __setattr__(self, key, value): + if self._frozen and not hasattr(self, key): + raise TypeError("{} does not contain attribute '{}' (object was frozen)".format(self, key)) + + object.__setattr__(self, key, value) + +class Euler(Freezable): + '''fixed step euler integration + + loosely based on scipy.integrate.RK45 + ''' + + def __init__(self, der_func, tstart, ystart, tend, step=0, time_tol=1e-9): + assert step > 0, "arg step > 0 required in Euler integrator" + assert tend > tstart + + self.der_func = der_func # signature (t, x) + self.tstep = step + self.t = tstart + self.y = ystart.copy() + self.yprev = None + self.tprev = None + self.tend = tend + + self.status = 'running' + + self.time_tol = time_tol + + self.freeze_attrs() + + def step(self): + 'take one step' + + if self.status == 'running': + self.yprev = self.y.copy() + self.tprev = self.t + yd = self.der_func(self.t, self.y) + + self.t += self.tstep + + if self.t + self.time_tol >= self.tend: + self.t = self.tend + + dt = self.t - self.tprev + self.y += dt * yd + + if self.t == self.tend: + self.status = 'finished' + + def dense_output(self): + 'return a function of time' + + assert self.tprev is not None + + dy = self.y - self.yprev + dt = self.t - self.tprev + + dydt = dy / dt + + def fun(t): + 'return state at time t (linear interpolation)' + + deltat = t - self.tprev + + return self.yprev + dydt * deltat + + return fun + +def get_state_names(): + 'returns a list of state variable names' + + return ['vt', 'alpha', 'beta', 'phi', 'theta', 'psi', 'P', 'Q', 'R', 'pos_n', 'pos_e', 'alt', 'pow'] + +def printmat(mat, main_label, row_label_str, col_label_str): + 'print a matrix' + + if isinstance(row_label_str, list) and len(row_label_str) == 0: + row_label_str = None + + assert isinstance(main_label, str) + assert row_label_str is None or isinstance(row_label_str, str) + assert isinstance(col_label_str, str) + + mat = np.array(mat) + if len(mat.shape) == 1: + mat.shape = (1, mat.shape[0]) # one-row matrix + + print("{main_label} =") + + row_labels = None if row_label_str is None else row_label_str.split(' ') + col_labels = col_label_str.split(' ') + + width = 7 + + width = max(width, max([len(l) for l in col_labels])) + + if row_labels is not None: + width = max(width, max([len(l) for l in row_labels])) + + width += 1 + + # add blank space for row labels + if row_labels is not None: + print("{: <{}}".format('', width), end='') + + # print col lables + for col_label in col_labels: + if len(col_label) > width: + col_label = col_label[:width] + + print("{: >{}}".format(col_label, width), end='') + + print('') + + if row_labels is not None: + assert len(row_labels) == mat.shape[0], \ + "row labels (len={}) expected one element for each row of the matrix ({})".format( \ + len(row_labels), mat.shape[0]) + + for r in range(mat.shape[0]): + row = mat[r] + + if row_labels is not None: + label = row_labels[r] + + if len(label) > width: + label = label[:width] + + print("{:<{}}".format(label, width), end='') + + for num in row: + #print("{:#<{}}".format(num, width), end='') + print("{:{}.{}g}".format(num, width, width-3), end='') + + print('') + + +def fix(ele): + 'round towards zero' + + assert isinstance(ele, float) + + if ele > 0: + rv = int(floor(ele)) + else: + rv = int(ceil(ele)) + + return rv + +def sign(ele): + 'sign of a number' + + if ele < 0: + rv = -1 + elif ele == 0: + rv = 0 + else: + rv = 1 + + return rv + +def extract_single_result(res, index, llc): + 'extract a res object for a sinlge aircraft from a multi-aircraft simulation' + + num_vars = len(get_state_names()) + llc.get_num_integrators() + num_aircraft = res['states'][0].size // num_vars + + if num_aircraft == 1: + assert index == 0 + rv = res + else: + rv = {} + rv['status'] = res['status'] + rv['times'] = res['times'] + rv['modes'] = res['modes'] + + full_states = res['states'] + rv['states'] = full_states[:, num_vars*index:num_vars*(index+1)] + + if 'xd_list' in res: + # extended states + key_list = ['xd_list', 'ps_list', 'Nz_list', 'Ny_r_list', 'u_list'] + + for key in key_list: + rv[key] = [tup[index] for tup in res[key]] + + return rv + +class SafetyLimits(Freezable): + 'a class for holding a set of safety limits.' + + def __init__(self, **kwargs): + self.altitude = kwargs['altitude'] if 'altitude' in kwargs and kwargs['altitude'] is not None else None + self.Nz = kwargs['Nz'] if 'Nz' in kwargs and kwargs['Nz'] is not None else None + self.v = kwargs['v'] if 'v' in kwargs and kwargs['v'] is not None else None + self.alpha = kwargs['alpha'] if 'alpha' in kwargs and kwargs['alpha'] is not None else None + + self.psMaxAccelDeg = kwargs['psMaxAccelDeg'] if 'psMaxAccelDeg' in kwargs and kwargs['psMaxAccelDeg'] is not None else None + self.betaMaxDeg = kwargs['betaMaxDeg'] if 'betaMaxDeg' in kwargs and kwargs['betaMaxDeg'] is not None else None + + self.freeze_attrs() + +class SafetyLimitsVerifier(Freezable): + 'given some limits (in a SafetyLimits) and optional low-level controller (in a LowLevelController), verify whether the simulation results are safe.' + + def __init__(self, safety_limits, llc=None): + self.safety_limits = safety_limits + self.llc = llc + + def verify(self, results): + # Determine the number of state variables per tick of the simulation. + if self.llc is not None: + num_state_vars = len(get_state_names()) + \ + self.llc.get_num_integrators() + else: + num_state_vars = len(get_state_names()) + # Check whether the results are sane. + assert (results['states'].size % num_state_vars) == 0, \ + "Wrong number of state variables." + + # Go through each tick of the simulation and determine whether + # the object(s) was (were) in a safe state. + for i in range(results['states'].size // num_state_vars): + _vt, alpha, beta, _phi, \ + _theta, _psi, _p, _q, \ + _r, _pos_n, _pos_e, alt, \ + _, _, _, _ = results['states'][i] + nz = results['Nz_list'][i] + ps = results['ps_list'][i] + + if self.safety_limits.altitude is not None: + assert self.safety_limits.altitude[0] <= alt <= self.safety_limits.altitude[1], "Altitude ({}) is not within the specified limits ({}, {}).".format(alt, self.safety_limits.altitude[0], self.safety_limits.altitude[1]) + + if self.safety_limits.Nz is not None: + assert self.safety_limits.Nz[0] <= nz <= self.safety_limits.Nz[1], "Nz ({}) is not within the specified limits ({}, {}).".format(nz, self.safety_limits.Nz[0], self.safety_limits.Nz[1]) + + if self.safety_limits.alpha is not None: + assert self.safety_limits.alpha[0] <= alpha <= self.safety_limits.alpha[1], "alpha ({}) is not within the specified limits ({}, {}).".format(nz, self.safety_limits.alpha[0], self.safety_limits.alpha[1]) + + if self.safety_limits.psMaxAccelDeg is not None: + assert ps <= self.safety_limits.psMaxAccelDeg, "Ps is not less than the specified max." + + if self.safety_limits.betaMaxDeg is not None: + assert beta <= self.safety_limits.betaMaxDeg, "Beta is not less than the specified max." diff --git a/demo/F16/aerobench/visualize/anim3d.py b/demo/F16/aerobench/visualize/anim3d.py new file mode 100644 index 0000000000000000000000000000000000000000..9042af44979c1f5ef92f4da2ccf967047526fe18 --- /dev/null +++ b/demo/F16/aerobench/visualize/anim3d.py @@ -0,0 +1,404 @@ +''' +3d plotting utilities for aerobench +''' + +import math +import time +import os +import traceback + +from scipy.io import loadmat + +import numpy as np +from numpy import rad2deg + +# these imports are needed for 3d plotting +from mpl_toolkits.mplot3d import Axes3D +from mpl_toolkits.mplot3d.art3d import Line3D, Poly3DCollection + +import matplotlib +import matplotlib.animation as animation +import matplotlib.pyplot as plt + +from aerobench.visualize import plot +from aerobench.util import StateIndex + +def get_script_path(filename=__file__): + '''get the path this script''' + return os.path.dirname(os.path.realpath(filename)) + +def make_anim(res, filename, viewsize=1000, viewsize_z=1000, f16_scale=30, trail_pts=60, + elev=30, azim=45, skip_frames=None, chase=False, fixed_floor=False, + init_extra=None, update_extra=None): + ''' + make a 3d plot of the GCAS maneuver. + + see examples/anim3d folder for examples on usage + ''' + + plot.init_plot() + start = time.time() + + if not isinstance(res, list): + res = [res] + + if not isinstance(viewsize, list): + viewsize = [viewsize] + + if not isinstance(viewsize_z, list): + viewsize_z = [viewsize_z] + + if not isinstance(f16_scale, list): + f16_scale = [f16_scale] + + if not isinstance(trail_pts, list): + trail_pts = [trail_pts] + + if not isinstance(elev, list): + elev = [elev] + + if not isinstance(azim, list): + azim = [azim] + + if not isinstance(skip_frames, list): + skip_frames = [skip_frames] + + if not isinstance(chase, list): + chase = [chase] + + if not isinstance(fixed_floor, list): + fixed_floor = [fixed_floor] + + if not isinstance(init_extra, list): + init_extra = [init_extra] + + if not isinstance(update_extra, list): + update_extra = [update_extra] + + ##### + # fill in defaults + if filename == '': + full_plot = False + else: + full_plot = True + + for i, skip in enumerate(skip_frames): + if skip is not None: + continue + + if filename == '': # plot to the screen + skip_frames[i] = 5 + elif filename.endswith('.gif'): + skip_frames[i] = 2 + else: + skip_frames[i] = 1 # plot every frame + + if filename == '': + filename = None + + ## + all_times = [] + all_states = [] + all_modes = [] + all_ps_list = [] + all_Nz_list = [] + + for r, skip in zip(res, skip_frames): + t = r['times'] + s = r['states'] + m = r['modes'] + ps = r['ps_list'] + Nz = r['Nz_list'] + + t = t[0::skip] + s = s[0::skip] + m = m[0::skip] + ps = ps[0::skip] + Nz = Nz[0::skip] + + all_times.append(t) + all_states.append(s) + all_modes.append(m) + all_ps_list.append(ps) + all_Nz_list.append(Nz) + + ## + + fig = plt.figure(figsize=(8, 7)) + ax = fig.add_subplot(111, projection='3d') + + ## + + parent = get_script_path() + plane_point_data = os.path.join(parent, 'f-16.mat') + + data = loadmat(plane_point_data) + f16_pts = data['V'] + f16_faces = data['F'] + + plane_polys = Poly3DCollection([], color=None if full_plot else 'k') + ax.add_collection3d(plane_polys) + + ax.set_xlabel('X [ft]', fontsize=14) + ax.set_ylabel('Y [ft]', fontsize=14) + ax.set_zlabel('Altitude [ft]', fontsize=14) + + # text + fontsize = 14 + time_text = ax.text2D(0.05, 0.97, "", transform=ax.transAxes, fontsize=fontsize) + mode_text = ax.text2D(0.95, 0.97, "", transform=ax.transAxes, fontsize=fontsize, horizontalalignment='right') + + alt_text = ax.text2D(0.05, 0.93, "", transform=ax.transAxes, fontsize=fontsize) + v_text = ax.text2D(0.95, 0.93, "", transform=ax.transAxes, fontsize=fontsize, horizontalalignment='right') + + alpha_text = ax.text2D(0.05, 0.89, "", transform=ax.transAxes, fontsize=fontsize) + beta_text = ax.text2D(0.95, 0.89, "", transform=ax.transAxes, fontsize=fontsize, horizontalalignment='right') + + nz_text = ax.text2D(0.05, 0.85, "", transform=ax.transAxes, fontsize=fontsize) + ps_text = ax.text2D(0.95, 0.85, "", transform=ax.transAxes, fontsize=fontsize, horizontalalignment='right') + + ang_text = ax.text2D(0.5, 0.81, "", transform=ax.transAxes, fontsize=fontsize, horizontalalignment='center') + + trail_line, = ax.plot([], [], [], color='r', lw=2, zorder=50) + + extra_lines = [] + + for func in init_extra: + if func is not None: + extra_lines.append(func(ax)) + else: + extra_lines.append([]) + + first_frames = [] + frames = 0 + + for t in all_times: + first_frames.append(frames) + frames += len(t) + + def anim_func(global_frame): + 'updates for the animation frame' + + index = 0 + first_frame = False + + for i, f in enumerate(first_frames): + if global_frame >= f: + index = i + + if global_frame == f: + first_frame = True + break + + frame = global_frame - first_frames[index] + states = all_states[index] + times = all_times[index] + modes = all_modes[index] + Nz_list = all_Nz_list[index] + ps_list = all_ps_list[index] + + print(f"Frame: {global_frame}/{frames} - Index {index} frame {frame}/{len(times)}") + + speed = states[frame][0] + alpha = states[frame][1] + beta = states[frame][2] + alt = states[frame][11] + + phi = states[frame][StateIndex.PHI] + theta = states[frame][StateIndex.THETA] + psi = states[frame][StateIndex.PSI] + + dx = states[frame][StateIndex.POS_E] + dy = states[frame][StateIndex.POS_N] + dz = states[frame][StateIndex.ALT] + + if first_frame: + ax.view_init(elev[index], azim[index]) + + for i, lines in enumerate(extra_lines): + for line in lines: + line.set_visible(i == index) + + time_text.set_text('t = {:.2f} sec'.format(times[frame])) + + if chase[index]: + ax.view_init(elev[index], rad2deg(-psi) - 90.0) + + colors = ['red', 'blue', 'green', 'magenta'] + + mode_names = [] + + for mode in modes: + if not mode in mode_names: + mode_names.append(mode) + + mode = modes[frame] + mode_index = modes.index(mode) + col = colors[mode_index % len(colors)] + mode_text.set_color(col) + mode_text.set_text('Mode: {}'.format(mode.capitalize())) + + alt_text.set_text('h = {:.2f} ft'.format(alt)) + v_text.set_text('V = {:.2f} ft/sec'.format(speed)) + + alpha_text.set_text('$\\alpha$ = {:.2f} deg'.format(rad2deg(alpha))) + beta_text.set_text('$\\beta$ = {:.2f} deg'.format(rad2deg(beta))) + + nz_text.set_text('$N_z$ = {:.2f} g'.format(Nz_list[frame])) + ps_text.set_text('$p_s$ = {:.2f} deg/sec'.format(rad2deg(ps_list[frame]))) + + ang_text.set_text('[$\\phi$, $\\theta$, $\\psi$] = [{:.2f}, {:.2f}, {:.2f}] deg'.format(\ + rad2deg(phi), rad2deg(theta), rad2deg(psi))) + + s = f16_scale[index] + s = 25 if s is None else s + pts = scale3d(f16_pts, [-s, s, s]) + + pts = rotate3d(pts, theta, psi - math.pi/2, -phi) + + size = viewsize[index] + size = 1000 if size is None else size + minx = dx - size + maxx = dx + size + miny = dy - size + maxy = dy + size + + vz = viewsize_z[index] + vz = 1000 if vz is None else vz + + if fixed_floor[index]: + minz = 0 + maxz = vz + else: + minz = dz - vz + maxz = dz + vz + + ax.set_xlim([minx, maxx]) + ax.set_ylim([miny, maxy]) + ax.set_zlim([minz, maxz]) + + verts = [] + fc = [] + ec = [] + count = 0 + + # draw ground + if minz <= 0 <= maxz: + z = 0 + verts.append([(minx, miny, z), (maxx, miny, z), (maxx, maxy, z), (minx, maxy, z)]) + fc.append('0.8') + ec.append('0.8') + + # draw f16 + for face in f16_faces: + face_pts = [] + + count = count + 1 + + if not full_plot and count % 10 != 0: + continue + + for findex in face: + face_pts.append((pts[findex-1][0] + dx, \ + pts[findex-1][1] + dy, \ + pts[findex-1][2] + dz)) + + verts.append(face_pts) + fc.append('0.2') + ec.append('0.2') + + plane_polys.set_verts(verts) + plane_polys.set_facecolor(fc) + plane_polys.set_edgecolor(ec) + + # do trail + t = trail_pts[index] + t = 200 if t is None else t + trail_len = t // skip_frames[index] + start_index = max(0, frame-trail_len) + + pos_xs = [pt[StateIndex.POS_E] for pt in states] + pos_ys = [pt[StateIndex.POS_N] for pt in states] + pos_zs = [pt[StateIndex.ALT] for pt in states] + + trail_line.set_data(np.asarray(pos_xs[start_index:frame]), np.asarray(pos_ys[start_index:frame])) + trail_line.set_3d_properties(np.asarray(pos_zs[start_index:frame])) + + if update_extra[index] is not None: + update_extra[index](frame) + + plt.tight_layout() + + interval = 30 + + if filename.endswith('.gif'): + interval = 60 + + anim_obj = animation.FuncAnimation(fig, anim_func, frames, interval=interval, \ + blit=False, repeat=True) + + if filename is not None: + + if filename.endswith('.gif'): + print("\nSaving animation to '{}' using 'imagemagick'...".format(filename)) + anim_obj.save(filename, dpi=60, writer='imagemagick') # dpi was 80 + print("Finished saving to {} in {:.1f} sec".format(filename, time.time() - start)) + else: + fps = 40 + codec = 'libx264' + + print("\nSaving '{}' at {:.2f} fps using ffmpeg with codec '{}'.".format(filename, fps, codec)) + + # if this fails do: 'sudo apt-get install ffmpeg' + try: + extra_args = [] + + if codec is not None: + extra_args += ['-vcodec', str(codec)] + + anim_obj.save(filename, fps=fps, extra_args=extra_args) + print("Finished saving to {} in {:.1f} sec".format(filename, time.time() - start)) + except AttributeError: + traceback.print_exc() + print("\nSaving video file failed! Is ffmpeg installed? Can you run 'ffmpeg' in the terminal?") + else: + plt.show() + +def scale3d(pts, scale_list): + 'scale a 3d ndarray of points, and return the new ndarray' + + assert len(scale_list) == 3 + + rv = np.zeros(pts.shape) + + for i in range(pts.shape[0]): + for d in range(3): + rv[i, d] = scale_list[d] * pts[i, d] + + return rv + +def rotate3d(pts, theta, psi, phi): + 'rotates an ndarray of 3d points, returns new list' + + sinTheta = math.sin(theta) + cosTheta = math.cos(theta) + sinPsi = math.sin(psi) + cosPsi = math.cos(psi) + sinPhi = math.sin(phi) + cosPhi = math.cos(phi) + + transform_matrix = np.array([ \ + [cosPsi * cosTheta, -sinPsi * cosTheta, sinTheta], \ + [cosPsi * sinTheta * sinPhi + sinPsi * cosPhi, \ + -sinPsi * sinTheta * sinPhi + cosPsi * cosPhi, \ + -cosTheta * sinPhi], \ + [-cosPsi * sinTheta * cosPhi + sinPsi * sinPhi, \ + sinPsi * sinTheta * cosPhi + cosPsi * sinPhi, \ + cosTheta * cosPhi]], dtype=float) + + rv = np.zeros(pts.shape) + + for i in range(pts.shape[0]): + rv[i] = np.dot(pts[i], transform_matrix) + + return rv diff --git a/demo/F16/aerobench/visualize/bak_matplotlib.mlpstyle b/demo/F16/aerobench/visualize/bak_matplotlib.mlpstyle new file mode 100644 index 0000000000000000000000000000000000000000..58c7faed1401e5334c5489c2343574ae9036ae4e --- /dev/null +++ b/demo/F16/aerobench/visualize/bak_matplotlib.mlpstyle @@ -0,0 +1,8 @@ +font.family : serif +xtick.labelsize : 14 +ytick.labelsize : 14 + +axes.labelsize : 20 +axes.titlesize : 28 + +path.simplify : False diff --git a/demo/F16/aerobench/visualize/f-16.mat b/demo/F16/aerobench/visualize/f-16.mat new file mode 100644 index 0000000000000000000000000000000000000000..dd42bdfb56ac815ef1c1cb6259d92a7d7b8c484d Binary files /dev/null and b/demo/F16/aerobench/visualize/f-16.mat differ diff --git a/demo/F16/aerobench/visualize/plot.py b/demo/F16/aerobench/visualize/plot.py new file mode 100644 index 0000000000000000000000000000000000000000..c690145292ed4fdb1fba6ea08dbcf0bfeacd73a0 --- /dev/null +++ b/demo/F16/aerobench/visualize/plot.py @@ -0,0 +1,294 @@ +''' +Stanley Bak +Python code for F-16 animation video output +''' + +import math +import os + +import numpy as np +from numpy import rad2deg + +import matplotlib +import matplotlib.animation as animation +import matplotlib.pyplot as plt + +from aerobench.util import get_state_names, StateIndex + +def get_script_path(filename=__file__): + '''get the path this script''' + return os.path.dirname(os.path.realpath(filename)) + +def init_plot(): + 'initialize plotting style' + + matplotlib.use('TkAgg') # set backend + + parent = get_script_path() + p = os.path.join(parent, 'bak_matplotlib.mlpstyle') + + plt.style.use(['bmh', p]) + +def plot_overhead(run_sim_result, waypoints=None, llc=None): + '''altitude over time plot from run_f16_sum result object + + note: call plt.show() afterwards to have plot show up + ''' + + init_plot() + + res = run_sim_result + fig = plt.figure(figsize=(7, 5)) + + ax = fig.add_subplot(1, 1, 1) + + full_states = res['states'] + + if llc is not None: + num_vars = len(get_state_names()) + llc.get_num_integrators() + num_aircraft = full_states[0, :].size // num_vars + else: + num_vars = full_states[0, :].size + num_aircraft = 1 + + for i in range(num_aircraft): + states = full_states[:, i*num_vars:(i+1)*num_vars] + + ys = states[:, StateIndex.POSN] # 9: n/s position (ft) + xs = states[:, StateIndex.POSE] # 10: e/w position (ft) + + ax.plot(xs, ys, '-') + + label = 'Start' if i == 0 else None + ax.plot([xs[0]], [ys[1]], 'k*', ms=8, label=label) + + if waypoints is not None: + xs = [wp[0] for wp in waypoints] + ys = [wp[1] for wp in waypoints] + + ax.plot(xs, ys, 'ro', label='Waypoints') + + ax.set_ylabel('North / South Position (ft)') + ax.set_xlabel('East / West Position (ft)') + + ax.set_title('Overhead Plot') + + ax.axis('equal') + + ax.legend() + + plt.tight_layout() + +def plot_attitude(run_sim_result, title='Attitude History', skip_yaw=True, figsize=(7, 5)): + 'plot a single variable over time' + + init_plot() + + res = run_sim_result + fig = plt.figure(figsize=figsize) + + ax = fig.add_subplot(1, 1, 1) + ax.ticklabel_format(useOffset=False) + + times = res['times'] + states = res['states'] + + indices = [StateIndex.PHI, StateIndex.THETA, StateIndex.PSI, StateIndex.P, StateIndex.Q, StateIndex.R] + labels = ['Roll (Phi)', 'Pitch (Theta)', 'Yaw (Psi)', 'Roll Rate (P)', 'Pitch Rate (Q)', 'Yaw Rate (R)'] + colors = ['r-', 'g-', 'b-', 'r--', 'g--', 'b--'] + + rad_to_deg_factor = 180 / math.pi + + for index, label, color in zip(indices, labels, colors): + if skip_yaw and index == StateIndex.PSI: + continue + + ys = states[:, index] # 11: altitude (ft) + + ax.plot(times, ys * rad_to_deg_factor, color, label=label) + + ax.set_ylabel('Attitudes & Rates (deg, deg/s)') + ax.set_xlabel('Time (sec)') + + if title is not None: + ax.set_title(title) + + ax.legend() + plt.tight_layout() + +def plot_outer_loop(run_sim_result, title='Outer Loop Controls'): + 'plot a single variable over time' + + init_plot() + + res = run_sim_result + assert 'u_list' in res, "Simulation must be run with extended_states=True" + fig = plt.figure(figsize=(7, 5)) + + ax = fig.add_subplot(1, 1, 1) + ax.ticklabel_format(useOffset=False) + + times = res['times'] + u_list = res['u_list'] + ps_list = res['ps_list'] + nz_list = res['Nz_list'] + ny_r_list = res['Ny_r_list'] + + # u is: throt, ele, ail, rud, Nz_ref, ps_ref, Ny_r_ref + # u_ref is: Nz, ps, Ny + r, throttle + ys_list = [] + + ys_list.append(nz_list) + ys_list.append([u[4] for u in u_list]) + + ys_list.append(ps_list) + ys_list.append([u[5] for u in u_list]) + + ys_list.append(ny_r_list) + ys_list.append([u[6] for u in u_list]) + + # throttle reference is not included... although it's just a small offset so probably less important + ys_list.append([u[0] for u in u_list]) + + labels = ['N_z', 'N_z,ref', 'P_s', 'P_s,ref', 'N_yr', 'N_yr,ref', 'Throttle'] + colors = ['r', 'r', 'lime', 'lime', 'b', 'b', 'c'] + + for i, (ys, label, color) in enumerate(zip(ys_list, labels, colors)): + lt = '-' if i % 2 == 0 else ':' + lw = 1 if i % 2 == 0 else 3 + + ax.plot(times, ys, lt, lw=lw, color=color, label=label) + + ax.set_ylabel('Autopilot (deg & percent)') + ax.set_xlabel('Time (sec)') + + if title is not None: + ax.set_title(title) + + ax.legend() + plt.tight_layout() + +def plot_inner_loop(run_sim_result, title='Inner Loop Controls'): + 'plot inner loop controls over time' + + init_plot() + + res = run_sim_result + assert 'u_list' in res, "Simulation must be run with extended_states=True" + fig = plt.figure(figsize=(7, 5)) + + ax = fig.add_subplot(1, 1, 1) + ax.ticklabel_format(useOffset=False) + + times = res['times'] + u_list = res['u_list'] + + # u is throt, ele, ail, rud, Nz_ref, ps_ref, Ny_r_ref + ys_list = [] + + rad_to_deg_factor = 180 / math.pi + + for i in range(4): + factor = 1.0 if i == 0 else rad_to_deg_factor + ys_list.append([u[i] * factor for u in u_list]) + + labels = ['Throttle', 'Elevator', 'Aileron', 'Rudder'] + colors = ['b-', 'r-', '#FFA500', 'm-'] + + for ys, label, color in zip(ys_list, labels, colors): + ax.plot(times, ys, color, label=label) + + ax.set_ylabel('Controls (deg & percent)') + ax.set_xlabel('Time (sec)') + + if title is not None: + ax.set_title(title) + + ax.legend() + plt.tight_layout() + +def plot_single(run_sim_result, state_name, title=None): + 'plot a single variable over time' + + init_plot() + + res = run_sim_result + fig = plt.figure(figsize=(7, 5)) + + ax = fig.add_subplot(1, 1, 1) + ax.ticklabel_format(useOffset=False) + + times = res['times'] + states = res['states'] + + index = get_state_names().index(state_name) + ys = states[:, index] # 11: altitude (ft) + + ax.plot(times, ys, '-') + + ax.set_ylabel(state_name) + ax.set_xlabel('Time') + + if title is not None: + ax.set_title(title) + + plt.tight_layout() + +def plot_altitude(run_sim_result): + '''altitude over time plot from run_f16_sum result object + + note: call plt.show() afterwards to have plot show up + ''' + + plot_single(run_sim_result, 'alt') + + +def plot2d(filename, times, plot_data_list): + '''plot state variables in 2d + + plot data list of is a list of (values_list, var_data), + where values_list is an 2-d array, the first is time step, the second is a state vector + and each var_data is a list of tuples: (state_index, label) + ''' + + num_plots = sum([len(var_data) for _, var_data in plot_data_list]) + + fig = plt.figure(figsize=(7, 5)) + + for plot_index in range(num_plots): + ax = fig.add_subplot(num_plots, 1, plot_index + 1) + ax.tick_params(axis='both', which='major', labelsize=16) + + sum_plots = 0 + states = None + state_var_data = None + + for values_list, var_data in plot_data_list: + if plot_index < sum_plots + len(var_data): + states = values_list + state_var_data = var_data + break + + sum_plots += len(var_data) + + state_index, label = state_var_data[plot_index - sum_plots] + + if state_index == 0 and isinstance(states[0], float): # state is just a single number + ys = states + else: + ys = [state[state_index] for state in states] + + ax.plot(times, ys, '-') + + ax.set_ylabel(label, fontsize=16) + + # last one gets an x axis label + if plot_index == num_plots - 1: + ax.set_xlabel('Time', fontsize=16) + + plt.tight_layout() + + if filename is not None: + plt.savefig(filename, bbox_inches='tight') + else: + plt.show() diff --git a/demo/F16/alt.png b/demo/F16/alt.png new file mode 100644 index 0000000000000000000000000000000000000000..2ce1015d7fd9d284e062ff5b3ae13a9a994d2c7f Binary files /dev/null and b/demo/F16/alt.png differ diff --git a/demo/F16/attitude.png b/demo/F16/attitude.png new file mode 100644 index 0000000000000000000000000000000000000000..e45ef4cfde0c619b767e5ffc6e142b6255aaa41f Binary files /dev/null and b/demo/F16/attitude.png differ diff --git a/demo/F16/inner_loop.png b/demo/F16/inner_loop.png new file mode 100644 index 0000000000000000000000000000000000000000..35a1669f956d0fa4cc17af0cd5c16927d0508bb7 Binary files /dev/null and b/demo/F16/inner_loop.png differ diff --git a/demo/F16/outer_loop.png b/demo/F16/outer_loop.png new file mode 100644 index 0000000000000000000000000000000000000000..7689d3aa137f10331f44a1c9db6faa2c75a9bec4 Binary files /dev/null and b/demo/F16/outer_loop.png differ diff --git a/demo/F16/overhead.png b/demo/F16/overhead.png new file mode 100644 index 0000000000000000000000000000000000000000..349b8404228524ce0096682069cef50f7cfab78c Binary files /dev/null and b/demo/F16/overhead.png differ diff --git a/demo/F16/waypoint_autopilot.py b/demo/F16/waypoint_autopilot.py new file mode 100644 index 0000000000000000000000000000000000000000..0c9575ff51826eb2452f965af05e6cf155fa2c11 --- /dev/null +++ b/demo/F16/waypoint_autopilot.py @@ -0,0 +1,307 @@ +'''waypoint autopilot + +ported from matlab v2 +''' + +from math import pi, atan2, sqrt, sin, cos, asin + +import numpy as np + +from aerobench.highlevel.autopilot import Autopilot +from aerobench.util import StateIndex +from aerobench.lowlevel.low_level_controller import LowLevelController + +class WaypointAutopilot(Autopilot): + '''waypoint follower autopilot''' + + def __init__(self, waypoints, gain_str='old', stdout=False): + 'waypoints is a list of 3-tuples' + + self.stdout = stdout + self.waypoints = waypoints + self.waypoint_index = 0 + + # waypoint config + self.cfg_slant_range_threshold = 250 + + # default control when not waypoint tracking + self.cfg_u_ol_default = (0, 0, 0, 0.3) + + # control config + # Gains for speed control + self.cfg_k_vt = 0.25 + self.cfg_airspeed = 550 + + # Gains for altitude tracking + self.cfg_k_alt = 0.005 + self.cfg_k_h_dot = 0.02 + + # Gains for heading tracking + self.cfg_k_prop_psi = 5 + self.cfg_k_der_psi = 0.5 + + # Gains for roll tracking + self.cfg_k_prop_phi = 0.75 + self.cfg_k_der_phi = 0.5 + self.cfg_max_bank_deg = 65 # maximum bank angle setpoint + # v2 was 0.5, 0.9 + + # Ranges for Nz + self.cfg_max_nz_cmd = 4 + self.cfg_min_nz_cmd = -1 + + self.done_time = 0.0 + + llc = LowLevelController(gain_str=gain_str) + + Autopilot.__init__(self, 'Waypoint 1', llc=llc) + + def log(self, s): + 'print to terminal if stdout is true' + + if self.stdout: + print(s) + + def get_u_ref(self, _t, x_f16): + '''get the reference input signals''' + + if self.mode != "Done": + psi_cmd = self.get_waypoint_data(x_f16)[0] + + # Get desired roll angle given desired heading + phi_cmd = self.get_phi_to_track_heading(x_f16, psi_cmd) + ps_cmd = self.track_roll_angle(x_f16, phi_cmd) + + nz_cmd = self.track_altitude(x_f16) + throttle = self.track_airspeed(x_f16) + else: + # Waypoint Following complete: fly level. + throttle = self.track_airspeed(x_f16) + ps_cmd = self.track_roll_angle(x_f16, 0) + nz_cmd = self.track_altitude_wings_level(x_f16) + + # trim to limits + nz_cmd = max(self.cfg_min_nz_cmd, min(self.cfg_max_nz_cmd, nz_cmd)) + throttle = max(min(throttle, 1), 0) + + # Create reference vector + rv = [nz_cmd, ps_cmd, 0, throttle] + + return rv + + def track_altitude(self, x_f16): + 'get nz to track altitude, taking turning into account' + + h_cmd = self.waypoints[self.waypoint_index][2] + + h = x_f16[StateIndex.ALT] + phi = x_f16[StateIndex.PHI] + + # Calculate altitude error (positive => below target alt) + h_error = h_cmd - h + nz_alt = self.track_altitude_wings_level(x_f16) + nz_roll = get_nz_for_level_turn_ol(x_f16) + + if h_error > 0: + # Ascend wings level or banked + nz = nz_alt + nz_roll + elif abs(phi) < np.deg2rad(15): + # Descend wings (close enough to) level + nz = nz_alt + nz_roll + else: + # Descend in bank (no negative Gs) + nz = max(0, nz_alt + nz_roll) + + return nz + + def get_phi_to_track_heading(self, x_f16, psi_cmd): + 'get phi from psi_cmd' + + # PD Control on heading angle using phi_cmd as control + + # Pull out important variables for ease of use + psi = wrap_to_pi(x_f16[StateIndex.PSI]) + r = x_f16[StateIndex.R] + + # Calculate PD control + psi_err = wrap_to_pi(psi_cmd - psi) + + phi_cmd = psi_err * self.cfg_k_prop_psi - r * self.cfg_k_der_psi + + # Bound to acceptable bank angles: + max_bank_rad = np.deg2rad(self.cfg_max_bank_deg) + + phi_cmd = min(max(phi_cmd, -max_bank_rad), max_bank_rad) + + return phi_cmd + + def track_roll_angle(self, x_f16, phi_cmd): + 'get roll angle command (ps_cmd)' + + # PD control on roll angle using stability roll rate + + # Pull out important variables for ease of use + phi = x_f16[StateIndex.PHI] + p = x_f16[StateIndex.P] + + # Calculate PD control + ps = (phi_cmd-phi) * self.cfg_k_prop_phi - p * self.cfg_k_der_phi + + return ps + + def track_airspeed(self, x_f16): + 'get throttle command' + + vt_cmd = self.cfg_airspeed + + # Proportional control on airspeed using throttle + throttle = self.cfg_k_vt * (vt_cmd - x_f16[StateIndex.VT]) + + return throttle + + def track_altitude_wings_level(self, x_f16): + 'get nz to track altitude' + + i = self.waypoint_index if self.waypoint_index < len(self.waypoints) else -1 + + h_cmd = self.waypoints[i][2] + + vt = x_f16[StateIndex.VT] + h = x_f16[StateIndex.ALT] + + # Proportional-Derivative Control + h_error = h_cmd - h + gamma = get_path_angle(x_f16) + h_dot = vt * sin(gamma) # Calculated, not differentiated + + # Calculate Nz command + nz = self.cfg_k_alt*h_error - self.cfg_k_h_dot*h_dot + + return nz + + def is_finished(self, t, x_f16): + 'is the maneuver done?' + + rv = self.waypoint_index >= len(self.waypoints) and self.done_time + 5.0 < t + + return rv + + def advance_discrete_mode(self, t, x_f16): + ''' + advance the discrete state based on the current aircraft state. Returns True iff the discrete state + has changed. + ''' + + if self.waypoint_index < len(self.waypoints): + slant_range = self.get_waypoint_data(x_f16)[-1] + + if slant_range < self.cfg_slant_range_threshold: + self.waypoint_index += 1 + + if self.waypoint_index >= len(self.waypoints): + self.done_time = t + + premode = self.mode + + if self.waypoint_index >= len(self.waypoints): + self.mode = 'Done' + else: + self.mode = f'Waypoint {self.waypoint_index + 1}' + + rv = premode != self.mode + + if rv: + self.log(f"Waypoint transition {premode} -> {self.mode} at time {t}") + + return rv + + def get_waypoint_data(self, x_f16): + '''returns current waypoint data tuple based on the current waypoint: + + (heading, inclination, horiz_range, vert_range, slant_range) + + heading = heading to tgt, equivalent to psi (rad) + inclination = polar angle to tgt, equivalent to theta (rad) + horiz_range = horizontal range to tgt (ft) + vert_range = vertical range to tgt (ft) + slant_range = total range to tgt (ft) + ''' + + waypoint = self.waypoints[self.waypoint_index] + + e_pos = x_f16[StateIndex.POSE] + n_pos = x_f16[StateIndex.POSN] + alt = x_f16[StateIndex.ALT] + + delta = [waypoint[i] - [e_pos, n_pos, alt][i] for i in range(3)] + + _, inclination, slant_range = cart2sph(delta) + + heading = wrap_to_pi(pi/2 - atan2(delta[1], delta[0])) + + horiz_range = np.linalg.norm(delta[0:2]) + vert_range = np.linalg.norm(delta[2]) + + return heading, inclination, horiz_range, vert_range, slant_range + +def get_nz_for_level_turn_ol(x_f16): + 'get nz to do a level turn' + + # Pull g's to maintain altitude during bank based on trig + + # Calculate theta + phi = x_f16[StateIndex.PHI] + + if abs(phi): # if cos(phi) ~= 0, basically + nz = 1 / cos(phi) - 1 # Keeps plane at altitude + else: + nz = 0 + + return nz + +def get_path_angle(x_f16): + 'get the path angle gamma' + + alpha = x_f16[StateIndex.ALPHA] # AoA (rad) + beta = x_f16[StateIndex.BETA] # Sideslip (rad) + phi = x_f16[StateIndex.PHI] # Roll anle (rad) + theta = x_f16[StateIndex.THETA] # Pitch angle (rad) + + gamma = asin((cos(alpha)*sin(theta)- \ + sin(alpha)*cos(theta)*cos(phi))*cos(beta) - \ + (cos(theta)*sin(phi))*sin(beta)) + + return gamma + +def wrap_to_pi(psi_rad): + '''handle angle wrapping + + returns equivelent angle in range [-pi, pi] + ''' + + rv = psi_rad % (2 * pi) + + if rv > pi: + rv -= 2 * pi + + return rv + +def cart2sph(pt3d): + ''' + Cartesian to spherical coordinates + + returns az, elev, r + ''' + + x, y, z = pt3d + + h = sqrt(x*x + y*y) + r = sqrt(h*h + z*z) + + elev = atan2(z, h) + az = atan2(y, x) + + return az, elev, r + +if __name__ == '__main__': + print("Autopulot script not meant to be run directly.") diff --git a/demo/ball_bounces.py b/demo/ball/ball_bounces.py similarity index 96% rename from demo/ball_bounces.py rename to demo/ball/ball_bounces.py index 68dee14d04b00e120cd7adbeb2891ae149c83837..8d5c6caf621ae1c56cd259d43ef907228b92b2da 100644 --- a/demo/ball_bounces.py +++ b/demo/ball/ball_bounces.py @@ -34,7 +34,7 @@ class State: pass -def controller(ego: State, others: State): +def controller(ego: State): '''Computes the possible mode transitions''' output = copy.deepcopy(ego) '''TODO: Ego and output variable names should be flexible but @@ -72,7 +72,7 @@ if __name__ == "__main__": 5. genetating the simulation traces or computing the reachable states ''' bouncingBall = Scenario() - ball_controller = './ball_bounces.py' + ball_controller = './demo/ball/ball_bounces.py' myball1 = BallAgent('red-ball', file_name=ball_controller) myball2 = BallAgent('green-ball', file_name=ball_controller) bouncingBall.add_agent(myball1) @@ -93,7 +93,7 @@ if __name__ == "__main__": # this may be the cause for the VisibleDeprecationWarning # TODO: Longer term: We should initialize by writing expressions like "-2 \leq myball1.x \leq 5" # "-2 \leq myball1.x + myball2.x \leq 5" - traces = bouncingBall.simulate(40, 0.1) + traces = bouncingBall.simulate(40,0.01) # TODO: There should be a print({traces}) function fig = go.Figure() fig = simulation_anime_trail( diff --git a/demo/ball_bounces_dev.py b/demo/ball_bounces_dev.py index b21960819091a513157506143f98e238c2a1b62a..e8f5ea255dd0a209f14c04875b0dbbd466327954 100644 --- a/demo/ball_bounces_dev.py +++ b/demo/ball_bounces_dev.py @@ -1,46 +1,44 @@ from enum import Enum, auto - +import copy class BallTypeMode(Enum): TYPE1 = auto() TYPE2 = auto() + class BallMode(Enum): Normal = auto() - + + class State: - x:float + x: float y = 0.0 vx = 0.0 vy = 0.0 mode: BallMode type: BallTypeMode - def __init__(self, x, y, vx, vy, ball_mode:BallMode, type: BallTypeMode): + + def __init__(self, x, y, vx, vy, ball_mode: BallMode, type: BallTypeMode): pass def controller(ego:State, other: State): output = copy.deepcopy(ego) - if ego.x<0: + if ego.x < 0: output.vx = -ego.vx - output.x=0 - if ego.y<0: + output.x = 0 + if ego.y < 0: output.vy = -ego.vy - output.y=0 - if ego.x>20: + output.y = 0 + if ego.x > 20: output.vx = -ego.vx - output.x=20 - if ego.y>20: + output.x = 20 + if ego.y > 20: output.vy = -ego.vy - output.y=20 - def abs_diff(a, b): - if a < b: - r = b - a - else: - r = a - b - return r - def dist(a, b): - return abs_diff(a.x, b.x) + abs_diff(a.y, b.y) - assert not (dist(ego, other) < 5 and ego.x < other.x), "collision" + output.y = 20 + + def close(a, b): + return a.x-b.x<5 and a.x-b.x>-5 and a.y-b.y<5 and a.y-b.y>-5 + assert not (close(ego, other) and ego.x < other.x), "collision" return output from dryvr_plus_plus.example import BallAgent @@ -70,7 +68,8 @@ if __name__ == "__main__": ] ) traces = bouncingBall.simulate(10, 0.01) + traces.dump('./output.json') + traces = AnalysisTree.load('./output.json') fig = go.Figure() fig = simulation_tree(traces, fig=fig) fig.show() - diff --git a/demo/controller/example_controller10.py b/demo/controller/example_controller10.py deleted file mode 100644 index b6e3550eef52d8eb90da501cb975c728379ec4cf..0000000000000000000000000000000000000000 --- a/demo/controller/example_controller10.py +++ /dev/null @@ -1,77 +0,0 @@ -from enum import Enum, auto -import copy -from dryvr_plus_plus.map import LaneMap - - -class VehicleMode(Enum): - Normal = auto() - SwitchLeft = auto() - SwitchRight = auto() - Brake = auto() - Accelerate = auto() - - -class LaneMode(Enum): - Lane0 = auto() - Lane1 = auto() - Lane2 = auto() - - -class LaneObjectMode(Enum): - Vehicle = auto() - Ped = auto() # Pedestrians - Sign = auto() # Signs, stop signs, merge, yield etc. - Signal = auto() # Traffic lights - Obstacle = auto() # Static (to road/lane) obstacles - - -class State: - x = 0.0 - y = 0.0 - theta = 0.0 - v = 0.0 - vehicle_mode: VehicleMode = VehicleMode.Normal - lane_mode: LaneMode = LaneMode.Lane0 - type: LaneObjectMode - - def __init__(self, x, y, theta, v, vehicle_mode: VehicleMode, lane_mode: LaneMode, type: LaneObjectMode): - self.data = [] - - -def controller(ego: State, other: State, sign: State, lane_map: LaneMap): - output = copy.deepcopy(ego) - if ego.vehicle_mode == VehicleMode.Normal: - if lane_map.get_longitudinal_position(other.lane_mode, [other.x, other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x, ego.y]) > 3 \ - and lane_map.get_longitudinal_position(other.lane_mode, [other.x, other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x, ego.y]) < 5 \ - and ego.lane_mode == other.lane_mode: - if lane_map.has_left(ego.lane_mode): - output.vehicle_mode = VehicleMode.SwitchLeft - if lane_map.get_longitudinal_position(other.lane_mode, [other.x, other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x, ego.y]) > 3 \ - and lane_map.get_longitudinal_position(other.lane_mode, [other.x, other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x, ego.y]) < 5 \ - and ego.lane_mode == other.lane_mode: - if lane_map.has_right(ego.lane_mode): - output.vehicle_mode = VehicleMode.SwitchRight - # if ego.lane_mode != other.lane_mode: - # output.vehicle_mode = VehicleMode.Accelerate - # if lane_map.get_longitudinal_position(other.lane_mode, [other.x, other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x, ego.y]) > 5: - # output.vehicle_mode = VehicleMode.Accelerate - if lane_map.get_speed_limit(ego.lane_mode) > 1: - output.vehicle_mode = VehicleMode.Accelerate - - if ego.vehicle_mode == VehicleMode.SwitchLeft: - if lane_map.get_lateral_distance(ego.lane_mode, [ego.x, ego.y]) >= 2.5: - output.vehicle_mode = VehicleMode.Normal - output.lane_mode = lane_map.left_lane(ego.lane_mode) - if ego.vehicle_mode == VehicleMode.SwitchRight: - if lane_map.get_lateral_distance(ego.lane_mode, [ego.x, ego.y]) <= -2.5: - output.vehicle_mode = VehicleMode.Normal - output.lane_mode = lane_map.right_lane(ego.lane_mode) - if ego.vehicle_mode == VehicleMode.Accelerate: - if lane_map.get_longitudinal_position(other.lane_mode, [other.x, other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x, ego.y]) > 3 \ - and lane_map.get_longitudinal_position(other.lane_mode, [other.x, other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x, ego.y]) < 5 \ - and ego.lane_mode == other.lane_mode: - output.vehicle_mode = VehicleMode.Normal - # if lane_map.get_speed_limit(ego.lane_mode, [ego.x, ego.y]) <= ego.v: - # output.vehicle_mode = VehicleMode.Normal - - return output diff --git a/demo/demo1.py b/demo/demo1.py deleted file mode 100644 index f491c41c7edb7b788063e51587afa15774177004..0000000000000000000000000000000000000000 --- a/demo/demo1.py +++ /dev/null @@ -1,58 +0,0 @@ -from dryvr_plus_plus.example import CarAgent, NPCAgent, SimpleMap2, FakeSensor2 -from dryvr_plus_plus import Scenario -from enum import Enum, auto -from dryvr_plus_plus.plotter.plotter2D import * - - -class VehicleMode(Enum): - Normal = auto() - SwitchLeft = auto() - SwitchRight = auto() - Brake = auto() - - -class LaneMode(Enum): - Lane0 = auto() - Lane1 = auto() - Lane2 = auto() - - -class State: - x = 0.0 - y = 0.0 - theta = 0.0 - v = 0.0 - vehicle_mode: VehicleMode = VehicleMode.Normal - lane_mode: LaneMode = LaneMode.Lane0 - - def __init__(self, x, y, theta, v, vehicle_mode: VehicleMode, lane_mode: LaneMode): - self.data = [] - - -if __name__ == "__main__": - input_code_name = './demo/controller/example_controller1.py' - scenario = Scenario() - - car = NPCAgent('car1') - scenario.add_agent(car) - car = CarAgent('car2', file_name=input_code_name) - scenario.add_agent(car) - tmp_map = SimpleMap2() - scenario.set_map(tmp_map) - scenario.set_sensor(FakeSensor2()) - scenario.set_init( - [ - [[10.0, 0, 0, 0.5], [10.0, 0, 0, 0.5]], - [[5.0, -0.2, 0, 2.0], [6.0, 0.2, 0, 3.0]], - ], - [ - (VehicleMode.Normal, LaneMode.Lane1), - (VehicleMode.Normal, LaneMode.Lane1), - ] - ) - - traces = scenario.simulate(10, 0.01) - fig = go.Figure() - fig = simulation_anime(traces, tmp_map, fig, 1, 2, - 'lines', 'trace', print_dim_list=[1, 2]) - fig.show() diff --git a/demo/demo10.py b/demo/demo10.py deleted file mode 100644 index bd7571f688187d76567aefc6827d73269c25e078..0000000000000000000000000000000000000000 --- a/demo/demo10.py +++ /dev/null @@ -1,80 +0,0 @@ -from dryvr_plus_plus.example import CarAgent, SimpleMap3_v2, FakeSensor2 -from dryvr_plus_plus import Scenario -from dryvr_plus_plus.plotter.plotter2D import * - -from enum import Enum, auto -import plotly.graph_objects as go - - -class VehicleMode(Enum): - Normal = auto() - SwitchLeft = auto() - SwitchRight = auto() - Brake = auto() - Accelerate = auto() - - -class LaneMode(Enum): - Lane0 = auto() - Lane1 = auto() - Lane2 = auto() - - -class State: - x = 0.0 - y = 0.0 - theta = 0.0 - v = 0.0 - vehicle_mode: VehicleMode = VehicleMode.Normal - lane_mode: LaneMode = LaneMode.Lane0 - - def __init__(self, x, y, theta, v, vehicle_mode: VehicleMode, lane_mode: LaneMode): - self.data = [] - - -if __name__ == "__main__": - input_code_name = './demo/controller/example_controller10.py' - scenario = Scenario() - - car = CarAgent('car1', file_name=input_code_name) - scenario.add_agent(car) - car = CarAgent('car2', file_name=input_code_name) - scenario.add_agent(car) - tmp_map = SimpleMap3_v2() - scenario.set_map(tmp_map) - scenario.set_sensor(FakeSensor2()) - scenario.set_init( - [ - [[0, -0.2, 0, 1.0], [0.1, 0.2, 0, 1.0]], - [[10, 0, 0, 0.5], [10, 0, 0, 0.5]], - ], - [ - (VehicleMode.Normal, LaneMode.Lane1), - (VehicleMode.Normal, LaneMode.Lane1), - ] - ) - # traces = scenario.verify(30) - # # fig = go.Figure() - # # fig = plotly_reachtube_tree_v2(traces, 'car1', 1, [2], 'blue', fig) - # # fig = plotly_reachtube_tree_v2(traces, 'car2', 1, [2], 'red', fig) - # # fig.show() - # fig = go.Figure() - # fig = generate_reachtube_anime(traces, tmp_map, fig) - # # # fig = plotly_reachtube_tree_v2(traces, 'car2', 1, [2], 'red', fig) - # fig.show() - # fig = plt.figure(2) - # fig = plot_map(tmp_map, 'g', fig) - # # fig = plot_simulation_tree(traces, 'car1', 1, [2], 'b', fig) - # # fig = plot_simulation_tree(traces, 'car2', 1, [2], 'r', fig) - # fig = plot_reachtube_tree(traces, 'car1', 1, [2], 'b', fig) - # fig = plot_reachtube_tree(traces, 'car2', 1, [2], 'r', fig) - # plt.show() - # # fig1 = plt.figure(2) - # fig = generate_simulation_anime(traces, tmp_map, fig) - # plt.show() - - traces = scenario.simulate(25, 0.05) - fig = go.Figure() - fig = simulation_anime( - traces, tmp_map, fig, 1, 2, 'detailed', 'trace', print_dim_list=[1, 2]) - fig.show() diff --git a/demo/dryvr_demo/rendezvous_controller.py b/demo/dryvr_demo/rendezvous_controller.py new file mode 100644 index 0000000000000000000000000000000000000000..0d1e310f79229d69582257eb8918ce685d86272e --- /dev/null +++ b/demo/dryvr_demo/rendezvous_controller.py @@ -0,0 +1,37 @@ +from enum import Enum, auto +import copy + + +class CraftMode(Enum): + ProxA = auto() + ProxB = auto() + Passive = auto() + + +class State: + xp = 0.0 + yp = 0.0 + xd = 0.0 + yd = 0.0 + total_time = 0.0 + cycle_time = 0.0 + craft_mode: CraftMode = CraftMode.ProxA + + def __init__(self, xp, yp, xd, yd, total_time, cycle_time, craft_mode: CraftMode): + pass + + +def controller(ego: State): + output = copy.deepcopy(ego) + if ego.craft_mode == CraftMode.ProxA: + if ego.yp >= -100 and ego.xp+ego.yp >= -141.1 and ego.xp >= -100 and ego.yp-ego.xp <= 141.1 and ego.yp <= 100 and ego.xp+ego.yp <= 141.1 and ego.xp <= 100 and ego.yp-ego.xp >= -141.1: + output.craft_mode = CraftMode.ProxB + if ego.cycle_time >= 120: + output.craft_mode = CraftMode.Passive + output.cycle_time = 0.0 + + if ego.craft_mode == CraftMode.ProxB: + if ego.cycle_time >= 120: + output.craft_mode = CraftMode.Passive + output.cycle_time = 0.0 + return output diff --git a/demo/dryvr_demo/rendezvous_demo.py b/demo/dryvr_demo/rendezvous_demo.py new file mode 100644 index 0000000000000000000000000000000000000000..78d87588a97d990d6d2dde4aa07535552ba168c3 --- /dev/null +++ b/demo/dryvr_demo/rendezvous_demo.py @@ -0,0 +1,32 @@ +from dryvr_plus_plus.example.example_agent.origin_agent import craft_agent +from dryvr_plus_plus import Scenario +from dryvr_plus_plus.plotter.plotter2D import * + +import plotly.graph_objects as go +from enum import Enum, auto + +class CraftMode(Enum): + ProxA = auto() + ProxB = auto() + Passive = auto() + +if __name__ == "__main__": + input_code_name = './demo/dryvr_demo/rendezvous_controller.py' + scenario = Scenario() + + car = craft_agent('test', file_name=input_code_name) + scenario.add_agent(car) + # modify mode list input + scenario.set_init( + [ + [[-925, -425, 0, 0, 0, 0], [-875, -375, 0, 0, 0, 0]], + ], + [ + tuple([CraftMode.ProxA]), + ] + ) + traces = scenario.verify(200, 1) + fig = go.Figure() + fig = reachtube_tree(traces, None, fig, 1, 2, + 'lines', 'trace', print_dim_list=[1, 2]) + fig.show() diff --git a/demo/dryvr_demo/thermo_controller.py b/demo/dryvr_demo/thermo_controller.py new file mode 100644 index 0000000000000000000000000000000000000000..59af1db6e804d9ebd37a6ef240dc7d74d4e569e7 --- /dev/null +++ b/demo/dryvr_demo/thermo_controller.py @@ -0,0 +1,30 @@ +from enum import Enum, auto +import copy + + +class ThermoMode(Enum): + ON = auto() + OFF = auto() + + +class State: + temp = 0.0 + total_time = 0.0 + cycle_time = 0.0 + thermo_mode: ThermoMode = ThermoMode.ON + + def __init__(self, temp, total_time, cycle_time, thermo_mode: ThermoMode): + pass + + +def controller(ego: State): + output = copy.deepcopy(ego) + if ego.thermo_mode == ThermoMode.ON: + if ego.cycle_time >= 1.0 and ego.cycle_time < 1.1: + output.thermo_mode = ThermoMode.OFF + output.cycle_time = 0.0 + if ego.thermo_mode == ThermoMode.OFF: + if ego.cycle_time >= 1.0 and ego.cycle_time < 1.1: + output.thermo_mode = ThermoMode.ON + output.cycle_time = 0.0 + return output diff --git a/demo/dryvr_demo/thermo_demo.py b/demo/dryvr_demo/thermo_demo.py new file mode 100644 index 0000000000000000000000000000000000000000..3bc8f8fb3aeed4c8ffd0528913a0a4dad75ba10b --- /dev/null +++ b/demo/dryvr_demo/thermo_demo.py @@ -0,0 +1,31 @@ +from dryvr_plus_plus.example.example_agent.origin_agent import thermo_agent +from dryvr_plus_plus import Scenario +from dryvr_plus_plus.plotter.plotter2D import * + +import plotly.graph_objects as go +from enum import Enum, auto + +class ThermoMode(Enum): + ON = auto() + OFF = auto() + +if __name__ == "__main__": + input_code_name = './demo/dryvr_demo/thermo_controller.py' + scenario = Scenario() + + car = thermo_agent('test', file_name=input_code_name) + scenario.add_agent(car) + # modify mode list input + scenario.set_init( + [ + [[75.0, 0.0, 0.0], [76, 0.0, 0.0]], + ], + [ + tuple([ThermoMode.ON]), + ] + ) + traces = scenario.verify(3.5, 0.05) + fig = go.Figure() + fig = reachtube_tree(traces, None, fig, 2, 1, + 'lines', 'trace', print_dim_list=[2, 1]) + fig.show() diff --git a/demo/dryvr_demo/thermo_demo2.py b/demo/dryvr_demo/thermo_demo2.py new file mode 100644 index 0000000000000000000000000000000000000000..cc0e6356ee97f7645a5eb9fc9f1692a42b7ced63 --- /dev/null +++ b/demo/dryvr_demo/thermo_demo2.py @@ -0,0 +1,37 @@ +from dryvr_plus_plus.example.example_agent.origin_agent import thermo_agent +from dryvr_plus_plus import Scenario +from dryvr_plus_plus.plotter.plotter2D import * + +import plotly.graph_objects as go +from enum import Enum, auto + +class ThermoMode(Enum): + ON = auto() + OFF = auto() + +if __name__ == "__main__": + input_code_name = './demo/dryvr_demo/thermo_controller.py' + scenario = Scenario() + + car = thermo_agent('test', file_name=input_code_name) + scenario.add_agent(car) + car = thermo_agent('test2', file_name=input_code_name) + scenario.add_agent(car) + # tmp_map = SimpleMap3() + # scenario.set_map(tmp_map) + # modify mode list input + scenario.set_init( + [ + [[75.0, 0.0, 0.0], [75.0, 0.0, 0.0]], + [[76.0, 0.0, 0.0], [76.0, 0.0, 0.0]], + ], + [ + tuple([ThermoMode.ON]), + tuple([ThermoMode.ON]), + ] + ) + traces = scenario.simulate(3.5, 0.05) + fig = go.Figure() + fig = simulation_tree(traces, None, fig, 2, 1, + 'lines', 'trace', print_dim_list=[2, 1]) + fig.show() diff --git a/demo/dryvr_demo/vanderpol_controller.py b/demo/dryvr_demo/vanderpol_controller.py new file mode 100644 index 0000000000000000000000000000000000000000..5a6de9525f697868dfdb528ca662909e7cf45819 --- /dev/null +++ b/demo/dryvr_demo/vanderpol_controller.py @@ -0,0 +1,18 @@ +from enum import Enum, auto +import copy + +class AgentMode(Enum): + Default = auto() + +class State: + x = 0.0 + y = 0.0 + agent_mode: AgentMode = AgentMode.Default + + def __init__(self, x, y, agent_mode: AgentMode): + pass + +def controller(ego: State): + output = copy.deepcopy(ego) + + return output diff --git a/demo/dryvr_demo/vanderpol_demo.py b/demo/dryvr_demo/vanderpol_demo.py new file mode 100644 index 0000000000000000000000000000000000000000..49fa5bcae16fdd7c42df0bf161f6b363b9f0b578 --- /dev/null +++ b/demo/dryvr_demo/vanderpol_demo.py @@ -0,0 +1,37 @@ +from dryvr_plus_plus.example.example_agent.origin_agent import vanderpol_agent +from dryvr_plus_plus import Scenario +from dryvr_plus_plus.plotter.plotter2D import * + +import plotly.graph_objects as go +from enum import Enum, auto + + +class AgentMode(Enum): + Default = auto() + + +if __name__ == "__main__": + input_code_name = './demo/dryvr_demo/vanderpol_controller.py' + scenario = Scenario() + + car = vanderpol_agent('car1', file_name=input_code_name) + scenario.add_agent(car) + # car = vanderpol_agent('car2', file_name=input_code_name) + # scenario.add_agent(car) + # scenario.set_sensor(FakeSensor2()) + # modify mode list input + scenario.set_init( + [ + [[1.25, 2.25], [1.25, 2.25]], + # [[1.55, 2.35], [1.55, 2.35]] + ], + [ + tuple([AgentMode.Default]), + # tuple([AgentMode.Default]), + ] + ) + traces = scenario.simulate(7, 0.05) + fig = go.Figure() + fig = simulation_tree(traces, None, fig, 1, 2, + 'lines', 'trace', print_dim_list=[1, 2]) + fig.show() diff --git a/demo/dryvr_demo/vanderpol_demo2.py b/demo/dryvr_demo/vanderpol_demo2.py new file mode 100644 index 0000000000000000000000000000000000000000..9e65011d97e54548efa0cbd4577ce45fb1607f04 --- /dev/null +++ b/demo/dryvr_demo/vanderpol_demo2.py @@ -0,0 +1,30 @@ +from dryvr_plus_plus.example.example_agent.origin_agent import vanderpol_agent +from dryvr_plus_plus import Scenario +from dryvr_plus_plus.plotter.plotter2D import * + +import plotly.graph_objects as go +from enum import Enum, auto + +class AgentMode(Enum): + Default = auto() + +if __name__ == "__main__": + input_code_name = './demo/dryvr_demo/vanderpol_controller.py' + scenario = Scenario() + + car = vanderpol_agent('car1', file_name=input_code_name) + scenario.add_agent(car) + # modify mode list input + scenario.set_init( + [ + [[1.25, 2.25], [1.55, 2.35]], + ], + [ + tuple([AgentMode.Default]), + ] + ) + traces = scenario.verify(7, 0.05) + fig = go.Figure() + fig = reachtube_tree(traces, None, fig, 1, 2, + 'lines', 'trace', print_dim_list=[1, 2]) + fig.show() diff --git a/demo/quadrotor/quadrotor_controller.py b/demo/quadrotor/quadrotor_controller.py new file mode 100644 index 0000000000000000000000000000000000000000..46bcb5e4fe5aca3de9a41fc922bbd32107be7218 --- /dev/null +++ b/demo/quadrotor/quadrotor_controller.py @@ -0,0 +1,47 @@ +from enum import Enum, auto +import copy + + +class CraftMode(Enum): + Follow_Waypoint = auto() + + +class State: + x = 0.0 + y = 0.0 + z = 0.0 + vx = 0.0 + vy = 0.0 + vz = 0.0 + craft_mode: CraftMode = CraftMode.Follow_Waypoint + waypoint_index: int = 0 + done_flag = 0.0 # indicate if the quad rotor reach the waypoint + + def __init__(self, x, y, z, vx, vy, vz, waypoint_index, done_flag, craft_mode): + pass + + +def controller(ego: State): + output = copy.deepcopy(ego) + if ego.craft_mode == CraftMode.Follow_Waypoint: + if ego.waypoint_index == 0 and ego.done_flag > 0: + output.craft_mode = CraftMode.Follow_Waypoint + output.waypoint_index = 1 + output.done_flag = 0 + if ego.waypoint_index == 1 and ego.done_flag > 0: + output.craft_mode = CraftMode.Follow_Waypoint + output.waypoint_index = 4 + output.done_flag = 0 + if ego.waypoint_index == 1 and ego.done_flag > 0: + output.craft_mode = CraftMode.Follow_Waypoint + output.waypoint_index = 2 + output.done_flag = 0 + if ego.waypoint_index == 2 and ego.done_flag > 0: + output.craft_mode = CraftMode.Follow_Waypoint + output.waypoint_index = 3 + output.done_flag = 0 + if ego.waypoint_index == 4 and ego.done_flag > 0: + output.craft_mode = CraftMode.Follow_Waypoint + output.waypoint_index = 5 + output.done_flag = 0 + return output diff --git a/demo/quadrotor/quadrotor_demo.py b/demo/quadrotor/quadrotor_demo.py new file mode 100644 index 0000000000000000000000000000000000000000..f7196ed8b37e80b45737d7f2d4e0f081267b167f --- /dev/null +++ b/demo/quadrotor/quadrotor_demo.py @@ -0,0 +1,59 @@ +from regex import B +from dryvr_plus_plus.example import QuadrotorAgent +from dryvr_plus_plus import Scenario +from dryvr_plus_plus.example import SimpleMap2, SimpleMap3, SimpleMap5, SimpleMap6 +from dryvr_plus_plus.plotter.plotter2D import * +import os +import json +import plotly.graph_objects as go +from enum import Enum, auto + + +class CraftMode(Enum): + Follow_Waypoint = auto() + + +if __name__ == "__main__": + input_code_name = './demo/quadrotor/quadrotor_controller.py' + scenario = Scenario() + + path = os.path.abspath(__file__) + path = path.replace('quadrotor_demo.py', 'test.json') + # print(path) + with open(path, 'r') as f: + prarms = json.load(f) + waypoints = [mode[1] for mode in prarms["agents"][0]["mode_list"]] + guard_boxes = [guard[1] for guard in prarms["agents"][0]["guards"]] + time_limits = prarms["agents"][0]["timeHorizons"] + # print(waypoints) + # print(guard_boxes) + quadrotor = QuadrotorAgent( + 'test', file_name=input_code_name, waypoints=waypoints, boxes=guard_boxes, time_limits=time_limits) + scenario.add_agent(quadrotor) + # scenario.set_sensor(QuadrotorSensor()) + # modify mode list input + scenario.set_init( + [ + [[2.75, -0.25, -0.1, 0, 0, 0, 0, 0], [3, 0, 0, 0.1, 0.1, 0.1, 0, 0]], + ], + [ + tuple([CraftMode.Follow_Waypoint]), + ] + ) + + traces = scenario.simulate(200, 0.2) + fig = go.Figure() + fig = simulation_tree(traces, None, fig, 1, 2, + 'lines', 'trace', print_dim_list=[0, 1, 2]) + fig = fig.add_trace(go.Scatter( + x=[3, 5, 5, 2, 2, 8, 8], y=[0, 0, 3, 3, 6, 3, 0], text=[0, 1, 2, 3, 4, 5, 6], mode='markers', marker={'color': 'black'})) + fig.show() + + + traces = scenario.verify(200, 0.2) + fig = go.Figure() + fig = reachtube_tree(traces, None, fig, 1, 2, + 'lines', 'trace', print_dim_list=[0, 1, 2]) + fig = fig.add_trace(go.Scatter( + x=[3, 5, 5, 2, 2, 8, 8], y=[0, 0, 3, 3, 6, 3, 0], text=[0, 1, 2, 3, 4, 5, 6], mode='markers', marker={'color': 'black'})) + fig.show() diff --git a/demo/quadrotor/test.json b/demo/quadrotor/test.json new file mode 100644 index 0000000000000000000000000000000000000000..76c9f53a4053a0bfade65d2675266fd1be9bd527 --- /dev/null +++ b/demo/quadrotor/test.json @@ -0,0 +1,34 @@ +{ + "grid_resolution": [0.5, 0.5, 0.5, 0.5, 0.5, 0.5], + "symmetry_level": "0", + "reachability_engine": "default", + "refine_threshold": 10, + "agents": [{ + "initialModeID": 0, + "segLength": -1, + "edge_list": [[0,1],[1,2],[2,3],[1,4],[4,5]], + "variables": ["x", "y", "z", "vx", "vy", "vz"], + "guards": [ + ["Box", [[4.6, -0.4, -1, -1, -100, -100], [5.4, 0.4, 1, 1, 100, 100]]], + ["Box", [[4.6, 2.6, -1, -100, -100, -100], [5.4, 3.4, 1, 100, 100, 100]]], + ["Box", [[1.6, 2.6, -1, -100, -100, -100], [2.4, 3.4, 1, 100, 100, 100]]], + ["Box", [[4.6, 2.6, -1, -100, -100, -100], [5.4, 3.4, 1, 100, 100, 100]]], + ["Box", [[7.6, 2.6, -1, -100, -100, -100], [8.4, 3.4, 1, 100, 100, 100]]] + ], + "timeHorizons": [20,20,20,20,20,20], + "directory": "examples/models/NNquadrotor_new_code_TR", + "mode_list": [ + ["follow_waypoint", [3, 0, 0, 5, 0, 0]], + ["follow_waypoint", [5, 0, 0, 5, 3, 0]], + ["follow_waypoint", [5, 3, 0, 2, 3, 0]], + ["follow_waypoint", [2, 3, 0, 2, 6, 0]], + ["follow_waypoint", [5, 3, 0, 8, 3, 0]], + ["follow_waypoint", [8, 3, 0, 8, 0, 0]] + ], + "initialSet": ["Box", [[2.75, -0.25, -0.1, 0, 0, 0], [3, 0,0, 0.1, 0.1, 0.1]]] + }], + "time_step": 0.1, + "unsafeSet": [ + ["Box",[[1.5,2.8,-100,-100,-100,-100],[1.7,3.2,100,100,100,100]]] + ] +} \ No newline at end of file diff --git a/demo/controller/example_controller1.py b/demo/vehicle/controller/example_controller1.py similarity index 61% rename from demo/controller/example_controller1.py rename to demo/vehicle/controller/example_controller1.py index 507f90c5ce0e1631ebae433d3b11f7802666751d..bc3c11401fe4fbd7c648d09614ef97a091014f9a 100644 --- a/demo/controller/example_controller1.py +++ b/demo/vehicle/controller/example_controller1.py @@ -1,22 +1,17 @@ from enum import Enum, auto import copy -from dryvr_plus_plus.map import LaneMap - - class VehicleMode(Enum): Normal = auto() SwitchLeft = auto() SwitchRight = auto() Brake = auto() - class LaneMode(Enum): Lane0 = auto() Lane1 = auto() Lane2 = auto() - class State: x = 0.0 y = 0.0 @@ -28,17 +23,17 @@ class State: def __init__(self, x, y, theta, v, vehicle_mode: VehicleMode, lane_mode: LaneMode): self.data = [] - -def controller(ego: State, other: State, lane_map: LaneMap): +def controller(ego:State, other:State, lane_map): output = copy.deepcopy(ego) if ego.vehicle_mode == VehicleMode.Normal: - if lane_map.get_longitudinal_position(other.lane_mode, [other.x, other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x, ego.y]) > 0 \ - and lane_map.get_longitudinal_position(other.lane_mode, [other.x, other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x, ego.y]) < 3 \ - and ego.lane_mode == other.lane_mode: + if lane_map.get_longitudinal_position(other.lane_mode, [other.x,other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x,ego.y]) > 0 \ + and lane_map.get_longitudinal_position(other.lane_mode, [other.x,other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x,ego.y]) < 3 \ + and ego.lane_mode == other.lane_mode: output.vehicle_mode = VehicleMode.Brake elif ego.vehicle_mode == VehicleMode.Brake: - if lane_map.get_longitudinal_position(other.lane_mode, [other.x, other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x, ego.y]) > 3 \ - or ego.lane_mode != other.lane_mode: + if lane_map.get_longitudinal_position(other.lane_mode, [other.x,other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x,ego.y]) > 10 \ + or ego.lane_mode != other.lane_mode: output.vehicle_mode = VehicleMode.Normal return output + diff --git a/demo/vehicle/controller/example_controller11.py b/demo/vehicle/controller/example_controller11.py new file mode 100644 index 0000000000000000000000000000000000000000..3b5e3857998a93203d764d2354b6df6a207c0f60 --- /dev/null +++ b/demo/vehicle/controller/example_controller11.py @@ -0,0 +1,39 @@ +from enum import Enum, auto +import copy + + +class VehicleMode(Enum): + Normal = auto() + SwitchLeft = auto() + SwitchRight = auto() + Brake = auto() + + +class LaneMode(Enum): + Lane0 = auto() + Lane1 = auto() + Lane2 = auto() + + +class State: + x = 0.0 + y = 0.0 + theta = 0.0 + v = 0.0 + vehicle_mode: VehicleMode = VehicleMode.Normal + lane_mode: LaneMode = LaneMode.Lane0 + + def __init__(self, x, y, theta, v, vehicle_mode: VehicleMode, lane_mode: LaneMode): + self.data = [] + + +def controller(ego: State, other: State, lane_map): + output = copy.deepcopy(ego) + if ego.vehicle_mode == VehicleMode.Normal: + if ego.x > 20 and ego.x < 25: + output.vehicle_mode = VehicleMode.Brake + elif ego.vehicle_mode == VehicleMode.Brake: + if ego.x >= 25: + output.vehicle_mode = VehicleMode.Normal + + return output diff --git a/demo/controller/example_controller2.py b/demo/vehicle/controller/example_controller2.py similarity index 100% rename from demo/controller/example_controller2.py rename to demo/vehicle/controller/example_controller2.py diff --git a/demo/controller/example_controller3.py b/demo/vehicle/controller/example_controller3.py similarity index 100% rename from demo/controller/example_controller3.py rename to demo/vehicle/controller/example_controller3.py diff --git a/demo/controller/example_controller4.py b/demo/vehicle/controller/example_controller4.py similarity index 100% rename from demo/controller/example_controller4.py rename to demo/vehicle/controller/example_controller4.py diff --git a/demo/controller/example_controller5.py b/demo/vehicle/controller/example_controller5.py similarity index 100% rename from demo/controller/example_controller5.py rename to demo/vehicle/controller/example_controller5.py diff --git a/demo/controller/example_controller6.py b/demo/vehicle/controller/example_controller6.py similarity index 100% rename from demo/controller/example_controller6.py rename to demo/vehicle/controller/example_controller6.py diff --git a/demo/controller/example_controller7.py b/demo/vehicle/controller/example_controller7.py similarity index 100% rename from demo/controller/example_controller7.py rename to demo/vehicle/controller/example_controller7.py diff --git a/demo/controller/example_controller8.py b/demo/vehicle/controller/example_controller8.py similarity index 100% rename from demo/controller/example_controller8.py rename to demo/vehicle/controller/example_controller8.py diff --git a/demo/controller/example_controller9.py b/demo/vehicle/controller/example_controller9.py similarity index 100% rename from demo/controller/example_controller9.py rename to demo/vehicle/controller/example_controller9.py diff --git a/demo/controller/example_two_car_sign_lane_switch.py b/demo/vehicle/controller/example_two_car_sign_lane_switch.py similarity index 73% rename from demo/controller/example_two_car_sign_lane_switch.py rename to demo/vehicle/controller/example_two_car_sign_lane_switch.py index 6f7d9798ce49f00c6dda9386d0923eb7ea09163f..79c5e324c6d4678d4c08d506ef655dd886ac01fe 100644 --- a/demo/controller/example_two_car_sign_lane_switch.py +++ b/demo/vehicle/controller/example_two_car_sign_lane_switch.py @@ -16,7 +16,6 @@ class LaneObjectMode(Enum): Signal = auto() # Traffic lights Obstacle = auto() # Static (to road/lane) obstacles - class VehicleMode(Enum): Normal = auto() SwitchLeft = auto() @@ -29,7 +28,6 @@ class LaneMode(Enum): Lane1 = auto() Lane2 = auto() - class State: x: float y: float @@ -50,25 +48,24 @@ class State: # self.lane_mode = lane_mode # self.obj_mode = obj_mode - def controller(ego: State, other: State, sign: State, lane_map): output = copy.deepcopy(ego) if ego.vehicle_mode == VehicleMode.Normal: if sign.type == LaneObjectMode.Obstacle and sign.x - ego.x < 3 and sign.x - ego.x > 0 and ego.lane_mode == sign.lane_mode: output.vehicle_mode = VehicleMode.SwitchLeft return output - if lane_map.get_longitudinal_position(other.lane_mode, [other.x, other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x, ego.y]) > 3 \ - and lane_map.get_longitudinal_position(other.lane_mode, [other.x, other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x, ego.y]) < 5 \ - and ego.lane_mode == other.lane_mode: + if lane_map.get_longitudinal_position(other.lane_mode, [other.x,other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x,ego.y]) > 3 \ + and lane_map.get_longitudinal_position(other.lane_mode, [other.x,other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x,ego.y]) < 5 \ + and ego.lane_mode == other.lane_mode: if lane_map.has_left(ego.lane_mode): output.vehicle_mode = VehicleMode.SwitchLeft - if lane_map.get_longitudinal_position(other.lane_mode, [other.x, other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x, ego.y]) > 3 \ - and lane_map.get_longitudinal_position(other.lane_mode, [other.x, other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x, ego.y]) < 5 \ - and ego.lane_mode == other.lane_mode: + if lane_map.get_longitudinal_position(other.lane_mode, [other.x,other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x,ego.y]) > 3 \ + and lane_map.get_longitudinal_position(other.lane_mode, [other.x,other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x,ego.y]) < 5 \ + and ego.lane_mode == other.lane_mode: if lane_map.has_right(ego.lane_mode): output.vehicle_mode = VehicleMode.SwitchRight if ego.vehicle_mode == VehicleMode.SwitchLeft: - if lane_map.get_lateral_distance(ego.lane_mode, [ego.x, ego.y]) >= 2.5: + if lane_map.get_lateral_distance(ego.lane_mode, [ego.x, ego.y]) >= 2.5: output.vehicle_mode = VehicleMode.Normal output.lane_mode = lane_map.left_lane(ego.lane_mode) if ego.vehicle_mode == VehicleMode.SwitchRight: @@ -79,6 +76,15 @@ def controller(ego: State, other: State, sign: State, lane_map): return output +from dryvr_plus_plus.example.example_agent.car_agent import CarAgent +from dryvr_plus_plus.example.example_agent.sign_agent import SignAgent +from dryvr_plus_plus.scene_verifier.scenario.scenario import Scenario +from dryvr_plus_plus.example.example_map.simple_map2 import SimpleMap3 +from dryvr_plus_plus.plotter.plotter2D import plot_reachtube_tree, plot_simulation_tree +from dryvr_plus_plus.example.example_sensor.fake_sensor import FakeSensor2 + +import matplotlib.pyplot as plt + if __name__ == "__main__": import sys input_code_name = sys.argv[0] @@ -93,9 +99,9 @@ if __name__ == "__main__": scenario.set_sensor(FakeSensor2()) scenario.set_init( [ - [[0, -0.2, 0, 1.0], [0.2, 0.2, 0, 1.0]], - [[10, 0, 0, 0.5], [10, 0, 0, 0.5]], - [[20, 0, 0, 0], [20, 0, 0, 0]], + [[0, -0.2, 0, 1.0],[0.2, 0.2, 0, 1.0]], + [[10, 0, 0, 0.5],[10, 0, 0, 0.5]], + [[20, 0, 0, 0],[20, 0, 0, 0]], ], [ (VehicleMode.Normal, LaneMode.Lane1, LaneObjectMode.Vehicle), diff --git a/demo/demo2.py b/demo/vehicle/demo2.py similarity index 56% rename from demo/demo2.py rename to demo/vehicle/demo2.py index 6cad100dd03191cdc1568f29baabe3b4b014f945..eae100aa0c4826d802db62f9ccf3f7b5c31a0f1f 100644 --- a/demo/demo2.py +++ b/demo/vehicle/demo2.py @@ -3,20 +3,17 @@ from dryvr_plus_plus import Scenario from dryvr_plus_plus.plotter.plotter2D import * from enum import Enum, auto - class VehicleMode(Enum): Normal = auto() SwitchLeft = auto() SwitchRight = auto() Brake = auto() - class LaneMode(Enum): Lane0 = auto() Lane1 = auto() Lane2 = auto() - class State: x = 0.0 y = 0.0 @@ -28,9 +25,8 @@ class State: def __init__(self, x, y, theta, v, vehicle_mode: VehicleMode, lane_mode: LaneMode): self.data = [] - if __name__ == "__main__": - input_code_name = './demo/controller/example_controller2.py' + input_code_name = './demo/vehicle/controller/example_controller2.py' scenario = Scenario() car = CarAgent('car1', file_name=input_code_name) @@ -51,10 +47,29 @@ if __name__ == "__main__": ] ) - traces = scenario.simulate(70, 0.05) - # traces = scenario.verify(70, 0.05) + # traces = scenario.simulate(30, 1) + # path = os.path.abspath('.') + # if os.path.exists(path+'/demo'): + # path += '/demo' + # path += '/output' + # if not os.path.exists(path): + # os.makedirs(path) + # file = path+"/output.json" + # write_json(traces, file) + + # root = read_json(file) + # fig = go.Figure() + # fig = simulation_tree(root, tmp_map, fig, 1, 2, 'lines') + # fig.show() + # # traces = scenario.verify(70, 0.05) + # fig = go.Figure() + # fig = simulation_tree(traces, tmp_map, fig, 1, 2, + # 'lines', 'trace', print_dim_list=[1, 2]) + # # # fig = reachtube_anime(traces, tmp_map, fig, 1, + # # # 2, 'lines', 'trace', print_dim_list=[1, 2]) + # fig.show() fig = go.Figure() - fig = simulation_anime_trail(traces, tmp_map, fig, 1, 2, 'lines') - # fig = reachtube_anime(traces, tmp_map, fig, 1, - # 2, 'lines', 'trace', print_dim_list=[1, 2]) + traces = scenario.verify(30, 1) + fig = reachtube_anime(traces, tmp_map, fig, 1, + 2, 'lines', 'trace', print_dim_list=[1, 2]) fig.show() diff --git a/demo/demo3.py b/demo/vehicle/demo3.py similarity index 65% rename from demo/demo3.py rename to demo/vehicle/demo3.py index 32508c0251e0e8474fa6a9e3081e777dea6748ae..2c9295cfe4d77a9309e522ad019217a4759a8405 100644 --- a/demo/demo3.py +++ b/demo/vehicle/demo3.py @@ -3,6 +3,7 @@ from dryvr_plus_plus import Scenario from enum import Enum, auto from dryvr_plus_plus.plotter.plotter2D import * +import plotly.graph_objects as go class LaneObjectMode(Enum): Vehicle = auto() @@ -11,20 +12,17 @@ class LaneObjectMode(Enum): Signal = auto() # Traffic lights Obstacle = auto() # Static (to road/lane) obstacles - class VehicleMode(Enum): Normal = auto() SwitchLeft = auto() SwitchRight = auto() Brake = auto() - class LaneMode(Enum): Lane0 = auto() Lane1 = auto() Lane2 = auto() - class State: x = 0.0 y = 0.0 @@ -39,7 +37,7 @@ class State: if __name__ == "__main__": - input_code_name = './controller/example_controller4.py' + input_code_name = './demo/vehicle/controller/example_controller4.py' scenario = Scenario() scenario.add_agent(CarAgent('car1', file_name=input_code_name)) @@ -50,10 +48,10 @@ if __name__ == "__main__": scenario.set_map(tmp_map) scenario.set_init( [ - [[0, -0.2, 0, 1.0], [0.01, 0.2, 0, 1.0]], - [[10, 0, 0, 0.5], [10, 0, 0, 0.5]], - [[20, 3, 0, 0.5], [20, 3, 0, 0.5]], - [[30, 0, 0, 0.5], [30, 0, 0, 0.5]], + [[0, -0.2, 0, 1.0],[0.01, 0.2, 0, 1.0]], + [[10, 0, 0, 0.5],[10, 0, 0, 0.5]], + [[20, 3, 0, 0.5],[20, 3, 0, 0.5]], + [[30, 0, 0, 0.5],[30, 0, 0, 0.5]], ], [ (VehicleMode.Normal, LaneMode.Lane1, LaneObjectMode.Vehicle), @@ -62,27 +60,17 @@ if __name__ == "__main__": (VehicleMode.Normal, LaneMode.Lane1, LaneObjectMode.Vehicle), ] ) - # traces = scenario.simulate(70, 0.05) - - # fig = go.Figure() - # fig = simulation_anime_trail(traces, tmp_map, fig, 1, - # 2, 'lines', 'trace', print_dim_list=[1, 2]) - # fig.show() - # fig = go.Figure() - # fig = simulation_anime(traces, tmp_map, fig, 1, - # 2, 'lines', 'trace', print_dim_list=[1, 2]) - # fig.show() - # fig = go.Figure() - # fig = simulation_tree(traces, tmp_map, fig, 1, - # 2, 'lines', 'trace', print_dim_list=[1, 2]) - # fig.show() + traces = scenario.simulate(70, 0.05) + # traces.dump('./output1.json') + fig = go.Figure() + fig = simulation_tree(traces, tmp_map, fig, 1, + 2, 'lines', 'trace', print_dim_list=[1, 2]) + fig.show() traces = scenario.verify(70, 0.1) - # fig = go.Figure() - # fig = reachtube_anime(traces, tmp_map, fig, 1, - # 2, 'lines', 'trace', print_dim_list=[1, 2]) - # fig.show() + # traces.dump('./output2.json') fig = go.Figure() fig = reachtube_tree(traces, tmp_map, fig, 1, 2, 'lines', 'trace', print_dim_list=[1, 2]) fig.show() + diff --git a/demo/demo4.py b/demo/vehicle/demo4.py similarity index 90% rename from demo/demo4.py rename to demo/vehicle/demo4.py index 9316da449398b91ffeb7180699ef26b17963f66a..fb957854743d8224b00ff0f9154255906d95c407 100644 --- a/demo/demo4.py +++ b/demo/vehicle/demo4.py @@ -4,6 +4,7 @@ from dryvr_plus_plus.plotter.plotter2D import * from enum import Enum, auto import plotly.graph_objects as go +from enum import Enum, auto class LaneObjectMode(Enum): @@ -42,7 +43,7 @@ class State: if __name__ == "__main__": - input_code_name = './demo/controller/example_controller8.py' + input_code_name = './demo/vehicle/controller/example_controller8.py' scenario = Scenario() car = CarAgent('car1', file_name=input_code_name) @@ -86,15 +87,15 @@ if __name__ == "__main__": ] ) - # traces = scenario.simulate(80, 0.1) - # fig = go.Figure() + traces = scenario.simulate(80, 0.1) + fig = go.Figure() # fig = simulation_anime( # traces, tmp_map, fig, 1, 2, 'lines', print_dim_list=[1, 2]) # fig.show() # fig = go.Figure() - # fig = simulation_tree( - # traces, tmp_map, fig, 1, 2, 'lines', print_dim_list=[1, 2]) - # fig.show() + fig = simulation_tree( + traces, tmp_map, fig, 1, 2, 'lines', print_dim_list=[1, 2]) + fig.show() traces = scenario.verify(80, 0.1) fig = go.Figure() diff --git a/demo/demo5.py b/demo/vehicle/demo5.py similarity index 84% rename from demo/demo5.py rename to demo/vehicle/demo5.py index b67f55f3b875318142b9dc70d9737d937f92382f..4ac12688bd724e7c9dbe8c2b2df626ba1e6cac30 100644 --- a/demo/demo5.py +++ b/demo/vehicle/demo5.py @@ -26,7 +26,7 @@ class LaneMode(Enum): if __name__ == "__main__": - input_code_name = './demo/controller/example_controller7.py' + input_code_name = './demo/vehicle/controller/example_controller7.py' scenario = Scenario() car = CarAgent('car1', file_name=input_code_name) @@ -61,10 +61,14 @@ if __name__ == "__main__": (LaneObjectMode.Vehicle,), ] ) - # traces = scenario.simulate(70, 0.05) traces = scenario.simulate(70, 0.05) + fig = go.Figure() + fig = simulation_tree(traces, tmp_map, fig, 1, + 2, 'lines', 'trace', print_dim_list=[1, 2]) + fig.show() + traces = scenario.verify(70, 0.05) fig = go.Figure() - fig = simulation_anime_trail(traces, tmp_map, fig, 1, + fig = reachtube_tree(traces, tmp_map, fig, 1, 2, 'lines', 'trace', print_dim_list=[1, 2]) fig.show() diff --git a/demo/demo6.py b/demo/vehicle/demo6.py similarity index 90% rename from demo/demo6.py rename to demo/vehicle/demo6.py index 2d974e792081c1410ae0e5e760b9161e44a79685..63f4da418baabfec5a010d6d7c8a3c0d7330fdbc 100644 --- a/demo/demo6.py +++ b/demo/vehicle/demo6.py @@ -41,7 +41,7 @@ class LaneMode(Enum): if __name__ == "__main__": - input_code_name = './demo/controller/example_controller7.py' + input_code_name = './demo/vehicle/controller/example_controller7.py' scenario = Scenario() car = CarAgent('car1', file_name=input_code_name) @@ -89,7 +89,6 @@ if __name__ == "__main__": (LaneObjectMode.Vehicle,), ] ) - traces = scenario.simulate(80, 0.05) # traces = scenario.verify(80, 0.05) # fig = plt.figure(2) @@ -102,7 +101,12 @@ if __name__ == "__main__": # fig = plot_reachtube_tree(traces, 'car6', 1, [2], 'r', fig) # plt.show() + traces = scenario.simulate(80, 0.05) fig = go.Figure() - fig = simulation_anime(traces, tmp_map, fig, 1, - 2, 'lines', 'trace', print_dim_list=[1, 2]) + fig = simulation_tree(traces, tmp_map, fig, 1, 2, 'lines', 'trace', print_dim_list=[1, 2]) fig.show() + + traces = scenario.verify(80, 0.05) + fig = go.Figure() + fig = reachtube_tree(traces, tmp_map, fig, 1, 2, 'lines', 'trace', print_dim_list=[1,2]) + fig.show() \ No newline at end of file diff --git a/demo/demo7.py b/demo/vehicle/demo7.py similarity index 89% rename from demo/demo7.py rename to demo/vehicle/demo7.py index 71032326138a7b03f793572cb09220368d49b19a..4dbafd2e95ba8d94776bf2f610070336bf7a193d 100644 --- a/demo/demo7.py +++ b/demo/vehicle/demo7.py @@ -3,11 +3,11 @@ from dryvr_plus_plus.example import CarAgent, NPCAgent, SimpleMap4 from dryvr_plus_plus import Scenario from dryvr_plus_plus.plotter.plotter2D import * +from dryvr_plus_plus.analysis import AnalysisTree from enum import Enum, auto import plotly.graph_objects as go - class LaneObjectMode(Enum): Vehicle = auto() Ped = auto() # Pedestrians @@ -44,10 +44,9 @@ class State: if __name__ == "__main__": - input_code_name = './demo/controller/example_controller8.py' + input_code_name = './demo/vehicle/controller/example_controller8.py' scenario = Scenario() - car = CarAgent('car1', file_name=input_code_name) scenario.add_agent(car) car = NPCAgent('car2') @@ -97,16 +96,16 @@ if __name__ == "__main__": (LaneObjectMode.Vehicle,), (LaneObjectMode.Vehicle,), ], - - ) - traces = scenario.simulate(20, 0.05) - # traces = scenario.verify(15, 0.05) - # fig = go.Figure() - # fig = simulation_anime(traces, tmp_map, fig, 1, - # 2, 'lines', 'trace', print_dim_list=[1, 2]) - # fig.show() + ) + traces = scenario.simulate(60, 0.05) fig = go.Figure() fig = simulation_tree(traces, tmp_map, fig, 1, 2, 'lines', 'trace', print_dim_list=[1, 2]) fig.show() + + traces = scenario.verify(60, 0.05) + fig = go.Figure() + fig = reachtube_tree(traces, tmp_map, fig, 1, + 2, 'lines', 'trace', print_dim_list=[1, 2]) + fig.show() \ No newline at end of file diff --git a/demo/demo8.py b/demo/vehicle/demo8.py similarity index 93% rename from demo/demo8.py rename to demo/vehicle/demo8.py index a7e1158603e75c8bd69538346de6632ecf21e942..6207cf8d1c4db74b6b5ea7130eab2ece779be0f4 100644 --- a/demo/demo8.py +++ b/demo/vehicle/demo8.py @@ -40,7 +40,7 @@ class State: if __name__ == "__main__": - input_code_name = './demo/controller/example_controller9.py' + input_code_name = './demo/vehicle/controller/example_controller9.py' scenario = Scenario() car = CarAgent('car1', file_name=input_code_name) diff --git a/dryvr_plus_plus/analysis/__init__.py b/dryvr_plus_plus/analysis/__init__.py index 0e99ae1521bab2784e681061fe1fa57caef67be5..5a4b8bda142af94fe791720ec9ed740d86a2f8a3 100644 --- a/dryvr_plus_plus/analysis/__init__.py +++ b/dryvr_plus_plus/analysis/__init__.py @@ -1,4 +1,4 @@ -from . import simulator, verifier, analysis_tree_node -from .analysis_tree_node import AnalysisTreeNode +from . import simulator, verifier, analysis_tree +from .analysis_tree import * from .simulator import Simulator from .verifier import Verifier diff --git a/dryvr_plus_plus/analysis/analysis_tree.py b/dryvr_plus_plus/analysis/analysis_tree.py new file mode 100644 index 0000000000000000000000000000000000000000..d04d0736d14ce05327b55fed49b0b4c9d2b3fb75 --- /dev/null +++ b/dryvr_plus_plus/analysis/analysis_tree.py @@ -0,0 +1,122 @@ +from typing import List, Dict, Any +import json + +class AnalysisTreeNode: + """AnalysisTreeNode class + A AnalysisTreeNode stores the continous execution of the system without transition happening""" + trace: Dict + """The trace for each agent. + The key of the dict is the agent id and the value of the dict is simulated traces for each agent""" + init: Dict + + def __init__( + self, + trace={}, + init={}, + mode={}, + static = {}, + agent={}, + assert_hits={}, + child=[], + start_time = 0, + ndigits = 10, + type = 'simtrace', + id = 0 + ): + self.trace:Dict = trace + self.init: Dict[str, List[float]] = init + self.mode: Dict[str, List[str]] = mode + self.agent: Dict = agent + self.child: List[AnalysisTreeNode] = child + self.start_time: float = round(start_time, ndigits) + self.assert_hits = assert_hits + self.type: str = type + self.static: Dict[str, List[str]] = static + self.id: int = id + + def to_dict(self): + rst_dict = { + 'id': self.id, + 'parent': None, + 'child': [], + 'agent': {}, + 'init': self.init, + 'mode': self.mode, + 'static': self.static, + 'start_time': self.start_time, + 'trace': self.trace, + 'type': self.type, + 'assert_hits': self.assert_hits + } + agent_dict = {} + for agent_id in self.agent: + agent_dict[agent_id] = f'{type(self.agent[agent_id])}' + rst_dict['agent'] = agent_dict + + return rst_dict + + @staticmethod + def from_dict(data) -> "AnalysisTreeNode": + return AnalysisTreeNode( + trace = data['trace'], + init = data['init'], + mode = data['mode'], + static = data['static'], + agent = data['agent'], + assert_hits = data['assert_hits'], + child = [], + start_time = data['start_time'], + type = data['type'], + ) + +class AnalysisTree: + def __init__(self, root): + self.root:AnalysisTreeNode = root + self.nodes:List[AnalysisTreeNode] = self.get_all_nodes(root) + + def get_all_nodes(self, root: AnalysisTreeNode) -> List[AnalysisTreeNode]: + # Perform BFS/DFS to store all the tree node in a list + res = [] + queue = [root] + node_id = 0 + while queue: + node = queue.pop(0) + node.id = node_id + res.append(node) + node_id += 1 + queue += node.child + return res + + def dump(self, fn): + res_dict = {} + converted_node = self.root.to_dict() + res_dict[self.root.id] = converted_node + queue = [self.root] + while queue: + parent_node = queue.pop(0) + for child_node in parent_node.child: + node_dict = child_node.to_dict() + node_dict['parent'] = parent_node.id + res_dict[child_node.id] = node_dict + res_dict[parent_node.id]['child'].append(child_node.id) + queue.append(child_node) + + with open(fn,'w+') as f: + json.dump(res_dict,f, indent=4, sort_keys=True) + + @staticmethod + def load(fn): + f = open(fn, 'r') + data = json.load(f) + f.close() + root_node_dict = data[str(0)] + root = AnalysisTreeNode.from_dict(root_node_dict) + queue = [(root_node_dict, root)] + while queue: + parent_node_dict, parent_node = queue.pop(0) + for child_node_idx in parent_node_dict['child']: + child_node_dict = data[str(child_node_idx)] + child_node = AnalysisTreeNode.from_dict(child_node_dict) + parent_node.child.append(child_node) + queue.append((child_node_dict, child_node)) + return AnalysisTree(root) \ No newline at end of file diff --git a/dryvr_plus_plus/analysis/analysis_tree_node.py b/dryvr_plus_plus/analysis/analysis_tree_node.py deleted file mode 100644 index a0f8cd814c48c84d8e43b643d3879d1bd065eb4d..0000000000000000000000000000000000000000 --- a/dryvr_plus_plus/analysis/analysis_tree_node.py +++ /dev/null @@ -1,33 +0,0 @@ -from typing import List, Dict, Any - - -class AnalysisTreeNode: - """AnalysisTreeNode class - A AnalysisTreeNode stores the continous execution of the system without transition happening""" - trace: Dict - """The trace for each agent. - The key of the dict is the agent id and the value of the dict is simulated traces for each agent""" - init: Dict - - def __init__( - self, - trace={}, - init={}, - mode={}, - static = {}, - agent={}, - assert_hits={}, - child=[], - start_time=0, - ndigits=10, - type='simtrace' - ): - self.trace: Dict = trace - self.init: Dict[str, List[float]] = init - self.mode: Dict[str, List[str]] = mode - self.agent: Dict = agent - self.child: List[AnalysisTreeNode] = child - self.start_time: float = round(start_time, ndigits) - self.assert_hits = assert_hits - self.type: str = type - self.static: Dict[str, List[str]] = static diff --git a/dryvr_plus_plus/analysis/simulator.py b/dryvr_plus_plus/analysis/simulator.py index 2298c145c492fca7e67dbaef47defec06b4d6c52..d8503485b05c512525f135ffac67cace48fdd693 100644 --- a/dryvr_plus_plus/analysis/simulator.py +++ b/dryvr_plus_plus/analysis/simulator.py @@ -1,18 +1,19 @@ from typing import List, Dict import copy -import itertools, functools +import itertools +import functools import pprint pp = functools.partial(pprint.pprint, compact=True, width=100) from dryvr_plus_plus.agents.base_agent import BaseAgent -from dryvr_plus_plus.analysis.analysis_tree_node import AnalysisTreeNode +from dryvr_plus_plus.analysis.analysis_tree import AnalysisTreeNode, AnalysisTree class Simulator: def __init__(self): - self.simulation_tree_root = None + self.simulation_tree = None - def simulate(self, init_list, init_mode_list, static_list, agent_list:List[BaseAgent], transition_graph, time_horizon, time_step, lane_map): + def simulate(self, init_list, init_mode_list, static_list, agent_list: List[BaseAgent], transition_graph, time_horizon, time_step, lane_map): # Setup the root of the simulation tree root = AnalysisTreeNode( trace={}, @@ -31,21 +32,20 @@ class Simulator: root.static[agent.id] = init_static root.agent[agent.id] = agent root.type = 'simtrace' - self.simulation_tree_root = root + simulation_queue = [] simulation_queue.append(root) # Perform BFS through the simulation tree to loop through all possible transitions while simulation_queue != []: - node:AnalysisTreeNode = simulation_queue.pop(0) + node: AnalysisTreeNode = simulation_queue.pop(0) pp((node.start_time, node.mode)) - remain_time = round(time_horizon - node.start_time,10) + remain_time = round(time_horizon - node.start_time, 10) if remain_time <= 0: continue # For trace not already simulated for agent_id in node.agent: if agent_id not in node.trace: # Simulate the trace starting from initial condition - # [time, x, y, theta, v] mode = node.mode[agent_id] init = node.init[agent_id] trace = node.agent[agent_id].TC_simulate( @@ -53,7 +53,8 @@ class Simulator: trace[:, 0] += node.start_time node.trace[agent_id] = trace.tolist() - asserts, transitions, transition_idx = transition_graph.get_transition_simulate_new(node) + asserts, transitions, transition_idx = transition_graph.get_transition_simulate_new( + node) node.assert_hits = asserts pp({a: trace[transition_idx] for a, trace in node.trace.items()}) @@ -74,16 +75,18 @@ class Simulator: else: # Generate the transition combinations if multiple agents can transit at the same time step transition_list = list(transitions.values()) - all_transition_combinations = itertools.product(*transition_list) + all_transition_combinations = itertools.product( + *transition_list) - # For each possible transition, construct the new node. + # For each possible transition, construct the new node. # Obtain the new initial condition for agent having transition # copy the traces that are not under transition for transition_combination in all_transition_combinations: - next_node_mode = copy.deepcopy(node.mode) + next_node_mode = copy.deepcopy(node.mode) next_node_static = copy.deepcopy(node.static) - next_node_agent = node.agent - next_node_start_time = list(truncated_trace.values())[0][0][0] + next_node_agent = node.agent + next_node_start_time = list( + truncated_trace.values())[0][0][0] next_node_init = {} next_node_trace = {} for transition in transition_combination: @@ -91,21 +94,21 @@ class Simulator: if dest_mode is None: continue # next_node = AnalysisTreeNode(trace = {},init={},mode={},agent={}, child = [], start_time = 0) - next_node_mode[transit_agent_idx] = dest_mode - next_node_init[transit_agent_idx] = next_init + next_node_mode[transit_agent_idx] = dest_mode + next_node_init[transit_agent_idx] = next_init for agent_idx in next_node_agent: if agent_idx not in next_node_init: next_node_trace[agent_idx] = truncated_trace[agent_idx] - + tmp = AnalysisTreeNode( - trace = next_node_trace, - init = next_node_init, - mode = next_node_mode, - static = next_node_static, - agent = next_node_agent, - child = [], - start_time = next_node_start_time, - type = 'simtrace' + trace=next_node_trace, + init=next_node_init, + mode=next_node_mode, + static=next_node_static, + agent=next_node_agent, + child=[], + start_time=next_node_start_time, + type='simtrace' ) node.child.append(tmp) simulation_queue.append(tmp) @@ -119,4 +122,6 @@ class Simulator: # start_time = next_node_start_time # )) # simulation_queue += node.child - return self.simulation_tree_root + + self.simulation_tree = AnalysisTree(root) + return self.simulation_tree diff --git a/dryvr_plus_plus/analysis/verifier.py b/dryvr_plus_plus/analysis/verifier.py index ac4fea85f0d06705ca03132b8ffd9fde3f5ef09b..f15622ac95a8f7199103b8d32e515f8e4c9d5f86 100644 --- a/dryvr_plus_plus/analysis/verifier.py +++ b/dryvr_plus_plus/analysis/verifier.py @@ -4,44 +4,44 @@ import copy import numpy as np from dryvr_plus_plus.agents.base_agent import BaseAgent -from dryvr_plus_plus.analysis.analysis_tree_node import AnalysisTreeNode +from dryvr_plus_plus.analysis.analysis_tree import AnalysisTreeNode, AnalysisTree from dryvr_plus_plus.reachability_engine.dryvr import calc_bloated_tube, SIMTRACENUM + class Verifier: def __init__(self): - self.reachtube_tree_root = None + self.reachtube_tree = None self.unsafe_set = None - self.verification_result = None + self.verification_result = None def compute_full_reachtube( self, init_list: List[float], init_mode_list: List[str], static_list: List[str], - agent_list:List[BaseAgent], - transition_graph, - time_horizon, - time_step, + agent_list: List[BaseAgent], + transition_graph, + time_horizon, + time_step, lane_map ): root = AnalysisTreeNode() for i, agent in enumerate(agent_list): root.init[agent.id] = init_list[i] init_mode = [elem.name for elem in init_mode_list[i]] - root.mode[agent.id] = init_mode + root.mode[agent.id] = init_mode init_static = [elem.name for elem in static_list[i]] root.static[agent.id] = init_static - root.agent[agent.id] = agent + root.agent[agent.id] = agent root.type = 'reachtube' - self.reachtube_tree_root = root verification_queue = [] verification_queue.append(root) while verification_queue != []: - node:AnalysisTreeNode = verification_queue.pop(0) - print(node.start_time, node.mode) - remain_time = round(time_horizon - node.start_time,10) + node: AnalysisTreeNode = verification_queue.pop(0) + # print(node.start_time, node.mode) + remain_time = round(time_horizon - node.start_time, 10) if remain_time <= 0: - continue + continue # For reachtubes not already computed # TODO: can add parallalization for this loop for agent_id in node.agent: @@ -64,7 +64,7 @@ class Verifier: lane_map = lane_map ) trace = np.array(cur_bloated_tube) - trace[:,0] += node.start_time + trace[:, 0] += node.start_time node.trace[agent_id] = trace.tolist() # print("here") @@ -79,38 +79,38 @@ class Verifier: max_end_idx = 0 for transition in all_possible_transitions: - transit_agent_idx, src_mode, dest_mode, next_init, idx = transition + transit_agent_idx, src_mode, dest_mode, next_init, idx = transition start_idx, end_idx = idx - + truncated_trace = {} for agent_idx in node.agent: truncated_trace[agent_idx] = node.trace[agent_idx][start_idx*2:] if end_idx > max_end_idx: max_end_idx = end_idx - + if dest_mode is None: continue - - next_node_mode = copy.deepcopy(node.mode) - next_node_mode[transit_agent_idx] = dest_mode - next_node_agent = node.agent + + next_node_mode = copy.deepcopy(node.mode) + next_node_mode[transit_agent_idx] = dest_mode + next_node_agent = node.agent next_node_start_time = list(truncated_trace.values())[0][0][0] next_node_init = {} next_node_trace = {} for agent_idx in next_node_agent: if agent_idx == transit_agent_idx: - next_node_init[agent_idx] = next_init + next_node_init[agent_idx] = next_init else: next_node_trace[agent_idx] = truncated_trace[agent_idx] - + tmp = AnalysisTreeNode( - trace = next_node_trace, - init = next_node_init, - mode = next_node_mode, - agent = next_node_agent, - child = [], - start_time = round(next_node_start_time,10), - type = 'reachtube' + trace=next_node_trace, + init=next_node_init, + mode=next_node_mode, + agent=next_node_agent, + child=[], + start_time=round(next_node_start_time, 10), + type='reachtube' ) node.child.append(tmp) verification_queue.append(tmp) @@ -119,5 +119,8 @@ class Verifier: """Only truncate when there's transitions""" if all_possible_transitions: for agent_idx in node.agent: - node.trace[agent_idx] = node.trace[agent_idx][:(max_end_idx+1)*2] - return root + node.trace[agent_idx] = node.trace[agent_idx][:( + max_end_idx+1)*2] + + self.reachtube_tree = AnalysisTree(root) + return self.reachtube_tree diff --git a/dryvr_plus_plus/automaton/guard.py b/dryvr_plus_plus/automaton/guard.py index ec844a0a396d5c9d147d1d400b0897b3a5dbc120..364c5ba61ac321b28a48b1600583e905e961a57d 100644 --- a/dryvr_plus_plus/automaton/guard.py +++ b/dryvr_plus_plus/automaton/guard.py @@ -10,13 +10,12 @@ from dryvr_plus_plus.agents.base_agent import BaseAgent from dryvr_plus_plus.code_parser import Reduction, ReductionType, unparse class LogicTreeNode: - def __init__(self, data, child=[], val=None, mode_guard=None): - self.data = data + def __init__(self, data, child = [], val = None, mode_guard = None): + self.data = data self.child = child self.val = val self.mode_guard = mode_guard - class NodeSubstituter(ast.NodeTransformer): def __init__(self, old_node, new_node): super().__init__() @@ -81,11 +80,11 @@ class ValueSubstituter(ast.NodeTransformer): class GuardExpressionAst: - def __init__(self, guard_list): + def __init__(self, guard_list, guard_idx = 0): self.ast_list = copy.deepcopy(guard_list) self.cont_variables = {} self.varDict = {} - + self.guard_idx = guard_idx def _build_guard(self, guard_str, agent): """ @@ -122,7 +121,7 @@ class GuardExpressionAst: self.cont_variables[cont_vars] = underscored self.varDict[underscored] = Real(underscored) - z3_string = self.generate_z3_expression() + z3_string = self.generate_z3_expression() if isinstance(z3_string, bool): return z3_string, z3_string @@ -135,13 +134,13 @@ class GuardExpressionAst: # The reachtube hits the guard cur_solver.pop() res = True + tmp_solver = Solver() tmp_solver.add(Not(cur_solver.assertions()[0])) for symbol in symbols: start, end = continuous_variable_dict[symbols[symbol]] tmp_solver.add(self.varDict[symbol] >= start, self.varDict[symbol] <= end) if tmp_solver.check() == unsat: - print("Full intersect, break") is_contained = True return res, is_contained @@ -182,7 +181,7 @@ class GuardExpressionAst: If without evaluating the continuous variables the result is True, then the guard condition will automatically be satisfied - + If without evaluating the continuous variables the result is False, then the guard condition will not be satisfied @@ -194,11 +193,11 @@ class GuardExpressionAst: # For each value in the boolop, check results if isinstance(node.op, ast.And): z3_str = [] - for i, val in enumerate(node.values): + for i,val in enumerate(node.values): tmp = self._generate_z3_expression_node(val) if isinstance(tmp, bool): if tmp: - continue + continue else: return False z3_str.append(tmp) @@ -224,7 +223,7 @@ class GuardExpressionAst: return z3_str # If string, construct string # If bool, check result and discard/evaluate result according to operator - pass + pass elif isinstance(node, ast.Constant): # If is bool, return boolean result if isinstance(node.value, bool): @@ -235,7 +234,7 @@ class GuardExpressionAst: expr = expr.strip('\n') return expr elif isinstance(node, ast.UnaryOp): - # If is UnaryOp, + # If is UnaryOp, value = self._generate_z3_expression_node(node.operand) if isinstance(node.op, ast.USub): return -value @@ -250,7 +249,7 @@ class GuardExpressionAst: expr = expr.strip('\n') return expr - def evaluate_guard_hybrid(self, agent, discrete_variable_dict, continuous_variable_dict, lane_map: LaneMap): + def evaluate_guard_hybrid(self, agent, discrete_variable_dict, continuous_variable_dict, lane_map:LaneMap): """ Handle guard atomics that contains both continuous and hybrid variables Especially, we want to handle function calls that need both continuous and @@ -260,11 +259,10 @@ class GuardExpressionAst: By doing this, all calls that need both continuous and discrete variables as input will now become only continuous variables. We can then handle these using what we already have for the continous variables """ - res = True + res = True for i, node in enumerate(self.ast_list): - tmp, self.ast_list[i] = self._evaluate_guard_hybrid( - node, agent, discrete_variable_dict, continuous_variable_dict, lane_map) - res = res and tmp + tmp, self.ast_list[i] = self._evaluate_guard_hybrid(node, agent, discrete_variable_dict, continuous_variable_dict, lane_map) + res = res and tmp return res def _evaluate_guard_hybrid(self, root, agent, disc_var_dict, cont_var_dict, lane_map:LaneMap): @@ -277,28 +275,24 @@ class GuardExpressionAst: if isinstance(root.op, ast.And): res = True for i, val in enumerate(root.values): - tmp, root.values[i] = self._evaluate_guard_hybrid( - val, agent, disc_var_dict, cont_var_dict, lane_map) - res = res and tmp + tmp, root.values[i] = self._evaluate_guard_hybrid(val, agent, disc_var_dict, cont_var_dict, lane_map) + res = res and tmp if not res: - break - return res, root + break + return res, root elif isinstance(root.op, ast.Or): res = False for val in root.values: - tmp, val = self._evaluate_guard_hybrid( - val, agent, disc_var_dict, cont_var_dict, lane_map) + tmp,val = self._evaluate_guard_hybrid(val, agent, disc_var_dict, cont_var_dict, lane_map) res = res or tmp return res, root elif isinstance(root, ast.BinOp): - left, root.left = self._evaluate_guard_hybrid( - root.left, agent, disc_var_dict, cont_var_dict, lane_map) - right, root.right = self._evaluate_guard_hybrid( - root.right, agent, disc_var_dict, cont_var_dict, lane_map) + left, root.left = self._evaluate_guard_hybrid(root.left, agent, disc_var_dict, cont_var_dict, lane_map) + right, root.right = self._evaluate_guard_hybrid(root.right, agent, disc_var_dict, cont_var_dict, lane_map) return True, root elif isinstance(root, ast.Call): if isinstance(root.func, ast.Attribute): - func = root.func + func = root.func if func.value.id == 'lane_map': if func.attr == 'get_lateral_distance': # Get function arguments @@ -326,21 +320,16 @@ class GuardExpressionAst: vehicle_pos = (arg1_lower, arg1_upper) # Get corresponding lane segments with respect to the set of vehicle pos - lane_seg1 = lane_map.get_lane_segment( - vehicle_lane, arg1_lower) - lane_seg2 = lane_map.get_lane_segment( - vehicle_lane, arg1_upper) + lane_seg1 = lane_map.get_lane_segment(vehicle_lane, arg1_lower) + lane_seg2 = lane_map.get_lane_segment(vehicle_lane, arg1_upper) # Compute the set of possible lateral values with respect to all possible segments - lateral_set1 = self._handle_lateral_set( - lane_seg1, np.array(vehicle_pos)) - lateral_set2 = self._handle_lateral_set( - lane_seg2, np.array(vehicle_pos)) + lateral_set1 = self._handle_lateral_set(lane_seg1, np.array(vehicle_pos)) + lateral_set2 = self._handle_lateral_set(lane_seg2, np.array(vehicle_pos)) # Use the union of two sets as the set of possible lateral positions - lateral_set = [min(lateral_set1[0], lateral_set2[0]), max( - lateral_set1[1], lateral_set2[1])] - + lateral_set = [min(lateral_set1[0], lateral_set2[0]), max(lateral_set1[1], lateral_set2[1])] + # Construct the tmp variable tmp_var_name = f'tmp_variable{len(cont_var_dict)+1}' # Add the tmp variable to the cont var dict @@ -375,21 +364,16 @@ class GuardExpressionAst: vehicle_pos = (arg1_lower, arg1_upper) # Get corresponding lane segments with respect to the set of vehicle pos - lane_seg1 = lane_map.get_lane_segment( - vehicle_lane, arg1_lower) - lane_seg2 = lane_map.get_lane_segment( - vehicle_lane, arg1_upper) + lane_seg1 = lane_map.get_lane_segment(vehicle_lane, arg1_lower) + lane_seg2 = lane_map.get_lane_segment(vehicle_lane, arg1_upper) # Compute the set of possible longitudinal values with respect to all possible segments - longitudinal_set1 = self._handle_longitudinal_set( - lane_seg1, np.array(vehicle_pos)) - longitudinal_set2 = self._handle_longitudinal_set( - lane_seg2, np.array(vehicle_pos)) + longitudinal_set1 = self._handle_longitudinal_set(lane_seg1, np.array(vehicle_pos)) + longitudinal_set2 = self._handle_longitudinal_set(lane_seg2, np.array(vehicle_pos)) # Use the union of two sets as the set of possible longitudinal positions - longitudinal_set = [min(longitudinal_set1[0], longitudinal_set2[0]), max( - longitudinal_set1[1], longitudinal_set2[1])] - + longitudinal_set = [min(longitudinal_set1[0], longitudinal_set2[0]), max(longitudinal_set1[1], longitudinal_set2[1])] + # Construct the tmp variable tmp_var_name = f'tmp_variable{len(cont_var_dict)+1}' # Add the tmp variable to the cont var dict @@ -404,7 +388,7 @@ class GuardExpressionAst: else: raise ValueError(f'Node type {root.func} from {unparse(root.func)} is not supported') elif isinstance(root, ast.Attribute): - return True, root + return True, root elif isinstance(root, ast.Constant): return root.value, root elif isinstance(root, ast.Name): @@ -426,26 +410,24 @@ class GuardExpressionAst: def _handle_longitudinal_set(self, lane_seg: AbstractLane, position: np.ndarray) -> List[float]: if lane_seg.type == "Straight": # Delta lower - delta0 = position[0, :] - lane_seg.start + delta0 = position[0,:] - lane_seg.start # Delta upper - delta1 = position[1, :] - lane_seg.start + delta1 = position[1,:] - lane_seg.start longitudinal_low = min(delta0[0]*lane_seg.direction[0], delta1[0]*lane_seg.direction[0]) + \ - min(delta0[1]*lane_seg.direction[1], - delta1[1]*lane_seg.direction[1]) + min(delta0[1]*lane_seg.direction[1], delta1[1]*lane_seg.direction[1]) longitudinal_high = max(delta0[0]*lane_seg.direction[0], delta1[0]*lane_seg.direction[0]) + \ - max(delta0[1]*lane_seg.direction[1], - delta1[1]*lane_seg.direction[1]) + max(delta0[1]*lane_seg.direction[1], delta1[1]*lane_seg.direction[1]) longitudinal_low += lane_seg.longitudinal_start longitudinal_high += lane_seg.longitudinal_start assert longitudinal_high >= longitudinal_low - return longitudinal_low, longitudinal_high + return longitudinal_low, longitudinal_high elif lane_seg.type == "Circular": # Delta lower - delta0 = position[0, :] - lane_seg.center + delta0 = position[0,:] - lane_seg.center # Delta upper - delta1 = position[1, :] - lane_seg.center + delta1 = position[1,:] - lane_seg.center phi0 = np.min([ np.arctan2(delta0[1], delta0[0]), @@ -460,66 +442,50 @@ class GuardExpressionAst: np.arctan2(delta1[1], delta1[0]), ]) - phi0 = lane_seg.start_phase + \ - wrap_to_pi(phi0 - lane_seg.start_phase) - phi1 = lane_seg.start_phase + \ - wrap_to_pi(phi1 - lane_seg.start_phase) + phi0 = lane_seg.start_phase + wrap_to_pi(phi0 - lane_seg.start_phase) + phi1 = lane_seg.start_phase + wrap_to_pi(phi1 - lane_seg.start_phase) longitudinal_low = min( - lane_seg.direction * - (phi0 - lane_seg.start_phase)*lane_seg.radius, - lane_seg.direction * - (phi1 - lane_seg.start_phase)*lane_seg.radius + lane_seg.direction * (phi0 - lane_seg.start_phase)*lane_seg.radius, + lane_seg.direction * (phi1 - lane_seg.start_phase)*lane_seg.radius ) + lane_seg.longitudinal_start longitudinal_high = max( - lane_seg.direction * - (phi0 - lane_seg.start_phase)*lane_seg.radius, - lane_seg.direction * - (phi1 - lane_seg.start_phase)*lane_seg.radius + lane_seg.direction * (phi0 - lane_seg.start_phase)*lane_seg.radius, + lane_seg.direction * (phi1 - lane_seg.start_phase)*lane_seg.radius ) + lane_seg.longitudinal_start assert longitudinal_high >= longitudinal_low return longitudinal_low, longitudinal_high else: - raise ValueError( - f'Lane segment with type {lane_seg.type} is not supported') + raise ValueError(f'Lane segment with type {lane_seg.type} is not supported') def _handle_lateral_set(self, lane_seg: AbstractLane, position: np.ndarray) -> List[float]: if lane_seg.type == "Straight": # Delta lower - delta0 = position[0, :] - lane_seg.start + delta0 = position[0,:] - lane_seg.start # Delta upper - delta1 = position[1, :] - lane_seg.start + delta1 = position[1,:] - lane_seg.start lateral_low = min(delta0[0]*lane_seg.direction_lateral[0], delta1[0]*lane_seg.direction_lateral[0]) + \ - min(delta0[1]*lane_seg.direction_lateral[1], - delta1[1]*lane_seg.direction_lateral[1]) + min(delta0[1]*lane_seg.direction_lateral[1], delta1[1]*lane_seg.direction_lateral[1]) lateral_high = max(delta0[0]*lane_seg.direction_lateral[0], delta1[0]*lane_seg.direction_lateral[0]) + \ - max(delta0[1]*lane_seg.direction_lateral[1], - delta1[1]*lane_seg.direction_lateral[1]) + max(delta0[1]*lane_seg.direction_lateral[1], delta1[1]*lane_seg.direction_lateral[1]) assert lateral_high >= lateral_low return lateral_low, lateral_high elif lane_seg.type == "Circular": - dx = np.max([position[0, 0]-lane_seg.center[0], - 0, lane_seg.center[0]-position[1, 0]]) - dy = np.max([position[0, 1]-lane_seg.center[1], - 0, lane_seg.center[1]-position[1, 1]]) + dx = np.max([position[0,0]-lane_seg.center[0],0,lane_seg.center[0]-position[1,0]]) + dy = np.max([position[0,1]-lane_seg.center[1],0,lane_seg.center[1]-position[1,1]]) r_low = np.linalg.norm([dx, dy]) - dx = np.max([np.abs(position[0, 0]-lane_seg.center[0]), - np.abs(position[1, 0]-lane_seg.center[0])]) - dy = np.max([np.abs(position[0, 1]-lane_seg.center[1]), - np.abs(position[1, 1]-lane_seg.center[1])]) + dx = np.max([np.abs(position[0,0]-lane_seg.center[0]),np.abs(position[1,0]-lane_seg.center[0])]) + dy = np.max([np.abs(position[0,1]-lane_seg.center[1]),np.abs(position[1,1]-lane_seg.center[1])]) r_high = np.linalg.norm([dx, dy]) - lateral_low = min(lane_seg.direction*(lane_seg.radius - r_high), - lane_seg.direction*(lane_seg.radius - r_low)) - lateral_high = max(lane_seg.direction*(lane_seg.radius - r_high), - lane_seg.direction*(lane_seg.radius - r_low)) + lateral_low = min(lane_seg.direction*(lane_seg.radius - r_high),lane_seg.direction*(lane_seg.radius - r_low)) + lateral_high = max(lane_seg.direction*(lane_seg.radius - r_high),lane_seg.direction*(lane_seg.radius - r_low)) # print(lateral_low, lateral_high) assert lateral_high >= lateral_low return lateral_low, lateral_high else: - raise ValueError( - f'Lane segment with type {lane_seg.type} is not supported') + raise ValueError(f'Lane segment with type {lane_seg.type} is not supported') def evaluate_guard_disc(self, agent, discrete_variable_dict, continuous_variable_dict, lane_map): """ @@ -538,7 +504,7 @@ class GuardExpressionAst: Recursively called function to evaluate guard with only discrete variables The function will evaluate all guards with discrete variables and replace the nodes with discrete guards by boolean constants - + :params: :return: The return value will be a tuple. The first element in the tuple will either be a boolean value or a the evaluated value of of an expression involving guard The second element in the tuple will be the updated ast node @@ -550,30 +516,29 @@ class GuardExpressionAst: if isinstance(left, bool) or isinstance(right, bool): return True, root if isinstance(root.ops[0], ast.GtE): - res = left >= right + res = left>=right elif isinstance(root.ops[0], ast.Gt): - res = left > right + res = left>right elif isinstance(root.ops[0], ast.Lt): - res = left < right + res = left<right elif isinstance(root.ops[0], ast.LtE): - res = left <= right + res = left<=right elif isinstance(root.ops[0], ast.Eq): - res = left == right + res = left == right elif isinstance(root.ops[0], ast.NotEq): - res = left != right + res = left != right else: raise ValueError(f'Node type {root} from {unparse(root)} is not supported') if res: root = ast.parse('True').body[0].value else: - root = ast.parse('False').body[0].value + root = ast.parse('False').body[0].value return res, root elif isinstance(root, ast.BoolOp): if isinstance(root.op, ast.And): res = True - for i, val in enumerate(root.values): - tmp, root.values[i] = self._evaluate_guard_disc( - val, agent, disc_var_dict, cont_var_dict, lane_map) + for i,val in enumerate(root.values): + tmp,root.values[i] = self._evaluate_guard_disc(val, agent, disc_var_dict, cont_var_dict, lane_map) res = res and tmp if not res: root = ast.Constant(value=False, kind=None) @@ -587,10 +552,8 @@ class GuardExpressionAst: return res, root elif isinstance(root, ast.BinOp): # Check left and right in the binop and replace all attributes involving discrete variables - left, root.left = self._evaluate_guard_disc( - root.left, agent, disc_var_dict, cont_var_dict, lane_map) - right, root.right = self._evaluate_guard_disc( - root.right, agent, disc_var_dict, cont_var_dict, lane_map) + left, root.left = self._evaluate_guard_disc(root.left, agent, disc_var_dict, cont_var_dict, lane_map) + right, root.right = self._evaluate_guard_disc(root.right, agent, disc_var_dict, cont_var_dict, lane_map) return True, root elif isinstance(root, ast.Call): expr = unparse(root) @@ -610,7 +573,7 @@ class GuardExpressionAst: if res: root = ast.parse('True').body[0].value else: - root = ast.parse('False').body[0].value + root = ast.parse('False').body[0].value else: # # TODO-PARSER: Handle This # for mode_name in agent.controller.mode_defs: @@ -679,22 +642,20 @@ class GuardExpressionAst: def _evaluate_guard(self, root, agent, cnts_var_dict, disc_var_dict, lane_map): if isinstance(root, ast.Compare): - left = self._evaluate_guard( - root.left, agent, cnts_var_dict, disc_var_dict, lane_map) - right = self._evaluate_guard( - root.comparators[0], agent, cnts_var_dict, disc_var_dict, lane_map) + left = self._evaluate_guard(root.left, agent, cnts_var_dict, disc_var_dict, lane_map) + right = self._evaluate_guard(root.comparators[0], agent, cnts_var_dict, disc_var_dict, lane_map) if isinstance(root.ops[0], ast.GtE): - return left >= right + return left>=right elif isinstance(root.ops[0], ast.Gt): - return left > right + return left>right elif isinstance(root.ops[0], ast.Lt): - return left < right + return left<right elif isinstance(root.ops[0], ast.LtE): - return left <= right + return left<=right elif isinstance(root.ops[0], ast.Eq): - return left == right + return left == right elif isinstance(root.ops[0], ast.NotEq): - return left != right + return left != right else: raise ValueError(f'Node type {root} from {unparse(root)} is not supported') @@ -702,8 +663,7 @@ class GuardExpressionAst: if isinstance(root.op, ast.And): res = True for val in root.values: - tmp = self._evaluate_guard( - val, agent, cnts_var_dict, disc_var_dict, lane_map) + tmp = self._evaluate_guard(val, agent, cnts_var_dict, disc_var_dict, lane_map) res = res and tmp if not res: break @@ -711,17 +671,14 @@ class GuardExpressionAst: elif isinstance(root.op, ast.Or): res = False for val in root.values: - tmp = self._evaluate_guard( - val, agent, cnts_var_dict, disc_var_dict, lane_map) + tmp = self._evaluate_guard(val, agent, cnts_var_dict, disc_var_dict, lane_map) res = res or tmp if res: break return res elif isinstance(root, ast.BinOp): - left = self._evaluate_guard( - root.left, agent, cnts_var_dict, disc_var_dict, lane_map) - right = self._evaluate_guard( - root.right, agent, cnts_var_dict, disc_var_dict, lane_map) + left = self._evaluate_guard(root.left, agent, cnts_var_dict, disc_var_dict, lane_map) + right = self._evaluate_guard(root.right, agent, cnts_var_dict, disc_var_dict, lane_map) if isinstance(root.op, ast.Sub): return left - right elif isinstance(root.op, ast.Add): @@ -743,7 +700,7 @@ class GuardExpressionAst: for arg in disc_var_dict: expr = expr.replace(arg, f'"{disc_var_dict[arg]}"') for arg in cnts_var_dict: - expr = expr.replace(arg, str(cnts_var_dict[arg])) + expr = expr.replace(arg, str(cnts_var_dict[arg])) res = eval(expr) # TODO-PARSER: Handle This for mode_name in agent.controller.mode_defs: @@ -774,8 +731,7 @@ class GuardExpressionAst: elif isinstance(root, ast.Constant): return root.value elif isinstance(root, ast.UnaryOp): - val = self._evaluate_guard( - root.operand, agent, cnts_var_dict, disc_var_dict, lane_map) + val = self._evaluate_guard(root.operand, agent, cnts_var_dict, disc_var_dict, lane_map) if isinstance(root.op, ast.USub): return -val if isinstance(root.op, ast.Not): @@ -1060,9 +1016,9 @@ class GuardExpressionAst: return root if __name__ == "__main__": - with open('tmp.pickle', 'rb') as f: + with open('tmp.pickle','rb') as f: guard_list = pickle.load(f) tmp = GuardExpressionAst(guard_list) # tmp.evaluate_guard() # tmp.construct_tree_from_str('(other_x-ego_x<20) and other_x-ego_x>10 and other_vehicle_lane==ego_vehicle_lane') - print("stop") + print("stop") \ No newline at end of file diff --git a/dryvr_plus_plus/code_parser/parser.py b/dryvr_plus_plus/code_parser/parser.py index e94fdd4575dcd753846ff7452e1e4c85ab4d8dcc..4e3260ab6ffcc7d15cbbd53893b26db54b8cf96c 100644 --- a/dryvr_plus_plus/code_parser/parser.py +++ b/dryvr_plus_plus/code_parser/parser.py @@ -1,4 +1,4 @@ -import ast, copy +import ast, copy, warnings from typing import List, Dict, Union, Optional, Any, Tuple from dataclasses import dataclass, field, fields from enum import Enum, auto @@ -741,6 +741,10 @@ def proc(node: ast.AST, env: Env) -> Any: return proc(node.value, env) if node.value != None else None elif isinstance(node, ast.IfExp): return node + elif isinstance(node, ast.Expr): + if isinstance(node.value, ast.Call): + warnings.warn(f"Effects of this call will not be included in the result: \"{unparse(node.value)}\"") + return None # Literals elif isinstance(node, ast.List): diff --git a/dryvr_plus_plus/example/example_agent/__init__.py b/dryvr_plus_plus/example/example_agent/__init__.py index 1f841e3b9b8c66e9d6c12b7c89d661ee182f6fc5..f5189e9d315e1dacf8ca5042ce7b19581e069139 100644 --- a/dryvr_plus_plus/example/example_agent/__init__.py +++ b/dryvr_plus_plus/example/example_agent/__init__.py @@ -1,4 +1,5 @@ -from . import ball_agent, car_agent, sign_agent +from . import ball_agent, car_agent, sign_agent, quadrotor_agent from .ball_agent import BallAgent from .car_agent import CarAgent, NPCAgent -from .sign_agent import SignAgent \ No newline at end of file +from .sign_agent import SignAgent +from .quadrotor_agent import QuadrotorAgent \ No newline at end of file diff --git a/dryvr_plus_plus/example/example_agent/ball_agent.py b/dryvr_plus_plus/example/example_agent/ball_agent.py index 22aa9419581bddfac8f321219460f18baa276704..2628d55f31299939cd0983e016f4e9c36e1b0e11 100644 --- a/dryvr_plus_plus/example/example_agent/ball_agent.py +++ b/dryvr_plus_plus/example/example_agent/ball_agent.py @@ -11,7 +11,8 @@ from dryvr_plus_plus import LaneMap class BallAgent(BaseAgent): '''Dynamics of a frictionless billiard ball on a 2D-plane''' - def __init__(self, id, code = None, file_name = None): + + def __init__(self, id, code=None, file_name=None): '''Contructor for tha agent EXACTLY one of the following should be given file_name: name of the controller @@ -30,25 +31,27 @@ class BallAgent(BaseAgent): vy_dot = 0 return [x_dot, y_dot, vx_dot, vy_dot] - def TC_simulate(self, mode: List[str], initialCondition, time_bound, time_step, lane_map:LaneMap=None)->np.ndarray: + def TC_simulate(self, mode: List[str], initialCondition, time_bound, time_step, lane_map: LaneMap = None) -> np.ndarray: # TODO: P1. Should TC_simulate really be part of the agent definition or should it be something more generic? # TODO: P2. Looks like this should be a global parameter; some config file should be setting this. time_bound = float(time_bound) number_points = int(np.ceil(time_bound/time_step)) - t = [round(i*time_step,10) for i in range(0,number_points)] + t = [round(i*time_step, 10) for i in range(0, number_points)] init = initialCondition trace = [[0]+init] for i in range(len(t)): r = ode(self.dynamic) r.set_initial_value(init) - res:np.ndarray = r.integrate(r.t + time_step) + res: np.ndarray = r.integrate(r.t + time_step) init = res.flatten().tolist() trace.append([t[i] + time_step] + init) return np.array(trace) + if __name__ == '__main__': - aball = BallAgent('red_ball',file_name="/Users/mitras/Dpp/GraphGeneration/demo/ball_bounces.py") - trace = aball.TC_simulate({'none'},[5, 10, 2, 2],10) + aball = BallAgent( + 'red_ball', file_name="/Users/mitras/Dpp/GraphGeneration/demo/ball_bounces.py") + trace = aball.TC_simulate({'none'}, [5, 10, 2, 2], 10,0.05) print(trace) diff --git a/dryvr_plus_plus/example/example_agent/car_agent.py b/dryvr_plus_plus/example/example_agent/car_agent.py index 445a3ae0fe90dfd8565aacb578f4cae2781670cc..5aee874c1ce9c6835a9761322d00b9728abf0102 100644 --- a/dryvr_plus_plus/example/example_agent/car_agent.py +++ b/dryvr_plus_plus/example/example_agent/car_agent.py @@ -1,14 +1,13 @@ # Example agent. from typing import Tuple, List -import numpy as np +import numpy as np from scipy.integrate import ode from dryvr_plus_plus import BaseAgent from dryvr_plus_plus import LaneMap from dryvr_plus_plus.code_parser import ControllerIR - class NPCAgent(BaseAgent): def __init__(self, id): self.id = id @@ -17,86 +16,76 @@ class NPCAgent(BaseAgent): @staticmethod def dynamic(t, state, u): x, y, theta, v = state - delta, a = u + delta, a = u x_dot = v*np.cos(theta+delta) y_dot = v*np.sin(theta+delta) theta_dot = v/1.75*np.sin(delta) - v_dot = a + v_dot = a return [x_dot, y_dot, theta_dot, v_dot] - - def action_handler(self, mode, state, lane_map: LaneMap) -> Tuple[float, float]: + def action_handler(self, mode, state, lane_map:LaneMap)->Tuple[float, float]: ''' Computes steering and acceleration based on current lane, target lane and current state using a Stanley controller-like rule''' - x, y, theta, v = state + x,y,theta,v = state vehicle_mode = mode[0] vehicle_lane = mode[1] - vehicle_pos = np.array([x, y]) + vehicle_pos = np.array([x,y]) d = -lane_map.get_lateral_distance(vehicle_lane, vehicle_pos) psi = lane_map.get_lane_heading(vehicle_lane, vehicle_pos)-theta steering = psi + np.arctan2(0.45*d, v) steering = np.clip(steering, -0.61, 0.61) a = 0 - return steering, a - + return steering, a def TC_simulate(self, mode: List[str], initialCondition, time_bound, time_step, lane_map:LaneMap=None)->np.ndarray: time_bound = float(time_bound) number_points = int(np.ceil(time_bound/time_step)) - t = [i*time_step for i in range(0, number_points)] + t = [i*time_step for i in range(0,number_points)] init = initialCondition trace = [[0]+init] for i in range(len(t)): steering, a = self.action_handler(mode, init, lane_map) - r = ode(self.dynamic) - r.set_initial_value(init).set_f_params([steering, a]) - res: np.ndarray = r.integrate(r.t + time_step) + r = ode(self.dynamic) + r.set_initial_value(init).set_f_params([steering, a]) + res:np.ndarray = r.integrate(r.t + time_step) init = res.flatten().tolist() - trace.append([t[i] + time_step] + init) + trace.append([t[i] + time_step] + init) return np.array(trace) - class CarAgent(BaseAgent): - def __init__(self, id, code=None, file_name=None): + def __init__(self, id, code = None, file_name = None): super().__init__(id, code, file_name) @staticmethod def dynamic(t, state, u): x, y, theta, v = state - delta, a = u + delta, a = u x_dot = v*np.cos(theta+delta) y_dot = v*np.sin(theta+delta) theta_dot = v/1.75*np.sin(delta) - v_dot = a + v_dot = a return [x_dot, y_dot, theta_dot, v_dot] - def action_handler(self, mode: List[str], state, lane_map: LaneMap) -> Tuple[float, float]: - x, y, theta, v = state + def action_handler(self, mode: List[str], state, lane_map:LaneMap)->Tuple[float, float]: + x,y,theta,v = state vehicle_mode = mode[0] vehicle_lane = mode[1] - vehicle_pos = np.array([x, y]) + vehicle_pos = np.array([x,y]) a = 0 if vehicle_mode == "Normal": d = -lane_map.get_lateral_distance(vehicle_lane, vehicle_pos) - # # keyi: just toy mod - # if v <= 2: - # a = 0.2 elif vehicle_mode == "SwitchLeft": d = -lane_map.get_lateral_distance(vehicle_lane, vehicle_pos) + 3 elif vehicle_mode == "SwitchRight": d = -lane_map.get_lateral_distance(vehicle_lane, vehicle_pos) - 3 elif vehicle_mode == "Brake": d = -lane_map.get_lateral_distance(vehicle_lane, vehicle_pos) - a = -1 - if v <= 0.02: - a = 0 - elif vehicle_mode == "Accelerate": + a = -1 + elif vehicle_mode == "Accel": d = -lane_map.get_lateral_distance(vehicle_lane, vehicle_pos) a = 1 - if v >= lane_map.get_speed_limit(vehicle_lane)-0.02: - a = 0 elif vehicle_mode == 'Stop': d = -lane_map.get_lateral_distance(vehicle_lane, vehicle_pos) a = 0 @@ -106,25 +95,24 @@ class CarAgent(BaseAgent): psi = lane_map.get_lane_heading(vehicle_lane, vehicle_pos)-theta steering = psi + np.arctan2(0.45*d, v) steering = np.clip(steering, -0.61, 0.61) - return steering, a - + return steering, a def TC_simulate(self, mode: List[str], initialCondition, time_bound, time_step, lane_map:LaneMap=None)->np.ndarray: time_bound = float(time_bound) number_points = int(np.ceil(time_bound/time_step)) t = [round(i*time_step,10) for i in range(0,number_points)] + init = initialCondition - # [time, x, y, theta, v] trace = [[0]+init] for i in range(len(t)): steering, a = self.action_handler(mode, init, lane_map) - r = ode(self.dynamic) - r.set_initial_value(init).set_f_params([steering, a]) - res: np.ndarray = r.integrate(r.t + time_step) + r = ode(self.dynamic) + r.set_initial_value(init).set_f_params([steering, a]) + res:np.ndarray = r.integrate(r.t + time_step) init = res.flatten().tolist() if init[3] < 0: init[3] = 0 - trace.append([t[i] + time_step] + init) + trace.append([t[i] + time_step] + init) return np.array(trace) diff --git a/dryvr_plus_plus/example/example_agent/origin_agent.py b/dryvr_plus_plus/example/example_agent/origin_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..1cb95bdae7249643afb98732ffde97a6d248a41d --- /dev/null +++ b/dryvr_plus_plus/example/example_agent/origin_agent.py @@ -0,0 +1,159 @@ +# Example agent. +from typing import Tuple, List + +import numpy as np +from scipy.integrate import ode + +from dryvr_plus_plus.agents import BaseAgent +from dryvr_plus_plus.map import LaneMap + + +class vanderpol_agent(BaseAgent): + def __init__(self, id, code=None, file_name=None): + # Calling the constructor of tha base class + super().__init__(id, code, file_name) + + @staticmethod + def dynamic(t, state): + x, y = state + x = float(x) + y = float(y) + x_dot = y + y_dot = (1-x**2)*y - x + return [x_dot, y_dot] + + def TC_simulate(self, mode: List[str], initialCondition, time_bound, time_step, lane_map: LaneMap = None) -> np.ndarray: + time_bound = float(time_bound) + number_points = int(np.ceil(time_bound/time_step)) + t = [round(i*time_step, 10) for i in range(0, number_points)] + # note: digit of time + init = initialCondition + trace = [[0]+init] + for i in range(len(t)): + r = ode(self.dynamic) + r.set_initial_value(init) + res: np.ndarray = r.integrate(r.t + time_step) + init = res.flatten().tolist() + trace.append([t[i] + time_step] + init) + return np.array(trace) + + +class thermo_agent(BaseAgent): + def __init__(self, id, code=None, file_name=None): + # Calling the constructor of tha base class + super().__init__(id, code, file_name) + + @staticmethod + def dynamic(t, state, rate): + temp, total_time, cycle_time = state + temp = float(temp) + total_time = float(total_time) + cycle_time = float(cycle_time) + temp_dot = temp*rate + total_time_dot = 1 + cycle_time_dot = 1 + return [temp_dot, total_time_dot, cycle_time_dot] + + def action_handler(self, mode): + if mode == 'ON': + rate = 0.1 + elif mode == 'OFF': + rate = -0.1 + else: + print(mode) + raise ValueError(f'Invalid mode: {mode}') + return rate + + def TC_simulate(self, mode: List[str], initialCondition, time_bound, time_step, lane_map: LaneMap = None) -> np.ndarray: + time_bound = float(time_bound) + number_points = int(np.ceil(time_bound/time_step)) + t = [round(i*time_step, 10) for i in range(0, number_points)] + + init = initialCondition + trace = [[0]+init] + for i in range(len(t)): + rate = self.action_handler(mode[0]) + r = ode(self.dynamic) + r.set_initial_value(init).set_f_params(rate) + res: np.ndarray = r.integrate(r.t + time_step) + init = res.flatten().tolist() + trace.append([t[i] + time_step] + init) + return np.array(trace) + + +class craft_agent(BaseAgent): + def __init__(self, id, code=None, file_name=None): + # Calling the constructor of tha base class + super().__init__(id, code, file_name) + + @staticmethod + def ProxA_dynamics(t, state): + xp, yp, xd, yd, total_time, cycle_time = state + xp_dot = xd + yp_dot = yd + xd_dot = -2.89995083970656*xd - 0.0576765518445905*xp + 0.00877200894463775*yd + 0.000200959896519766 * \ + yp - (1.43496e+18*xp + 6.050365344e+25)*pow(pow(yp, 2) + + pow(xp + 42164000, 2), -1.5) + 807.153595726846 + yd_dot = -0.00875351105536225*xd - 0.000174031357370456*xp - 2.90300269286856*yd - \ + 1.43496e+18*yp*pow(pow(yp, 2) + pow(xp + 42164000, + 2), -1.5) - 0.0664932019993982*yp + total_time_dot = 1 + cycle_time_dot = 1 + return [xp_dot, yp_dot, xd_dot, yd_dot, total_time_dot, cycle_time_dot] + + @staticmethod + def ProxB_dynamics(t, state): + xp, yp, xd, yd, total_time, cycle_time = state + xp_dot = xd + yp_dot = yd + xd_dot = -19.2299795908647*xd - 0.576076729033652*xp + 0.00876275931760007*yd + 0.000262486079431672 * \ + yp - (1.43496e+18*xp + 6.050365344e+25)*pow(pow(yp, 2) + + pow(xp + 42164000, 2), -1.5) + 807.153595726846 + yd_dot = -0.00876276068239993*xd - 0.000262486080737868*xp - 19.2299765959399*yd - \ + 1.43496e+18*yp*pow(pow(yp, 2) + pow(xp + 42164000, + 2), -1.5) - 0.575980743701182*yp + total_time_dot = 1 + cycle_time_dot = 1 + return [xp_dot, yp_dot, xd_dot, yd_dot, total_time_dot, cycle_time_dot] + + @staticmethod + def Passive_dynamics(t, state): + xp, yp, xd, yd, total_time, cycle_time = state + xp_dot = xd + yp_dot = yd + xd_dot = 0.0000575894721132000*xp+0.00876276*yd + yd_dot = -0.00876276*xd + total_time_dot = 1 + cycle_time_dot = 1 + return [xp_dot, yp_dot, xd_dot, yd_dot, total_time_dot, cycle_time_dot] + + def action_handler(self, mode): + if mode == 'ProxA': + return ode(self.ProxA_dynamics) + elif mode == 'ProxB': + return ode(self.ProxB_dynamics) + elif mode == 'Passive': + return ode(self.Passive_dynamics) + else: + raise ValueError + + def TC_simulate(self, mode: List[str], initialCondition, time_bound, time_step, lane_map: LaneMap = None) -> np.ndarray: + time_bound = float(time_bound) + number_points = int(np.ceil(time_bound/time_step)) + t = [round(i*time_step, 10) for i in range(0, number_points)] + + init = initialCondition + trace = [[0]+init] + for i in range(len(t)): + r = self.action_handler(mode[0]) + r.set_initial_value(init) + res: np.ndarray = r.integrate(r.t + time_step) + init = res.flatten().tolist() + trace.append([t[i] + time_step] + init) + return np.array(trace) + + +if __name__ == '__main__': + aball = vanderpol_agent('agent1') + trace = aball.TC_Simulate(['none'], [1.25, 2.25], 7, 0.05) + print(trace) diff --git a/dryvr_plus_plus/example/example_agent/prarm.json b/dryvr_plus_plus/example/example_agent/prarm.json new file mode 100644 index 0000000000000000000000000000000000000000..2bb05cb7de1a4af54276aa83a5d1493345cc58f4 --- /dev/null +++ b/dryvr_plus_plus/example/example_agent/prarm.json @@ -0,0 +1,184 @@ +{ + "bias1" : [0.005067077465355396, 0.013770184479653835, 0.02960527129471302, -0.005076207220554352, + 0.00986097939312458, -0.004963981453329325, 0.005067163147032261, 0.022097138687968254, + 0.005067095160484314, -0.009861057624220848, -0.005051563028246164, -0.013770153746008873, + 0.009860980324447155, 0.009861058555543423, 0.013770055025815964, 0.005067105405032635, + -0.01376994326710701, 0.005067072343081236, -0.005067067686468363, -0.013769976794719696], + "bias2" : [-0.009237067773938179, -0.00437132129445672, 0.0007612668559886515, 0.009431629441678524, + -0.04935901612043381, 0.00892704352736473, 0.00891881249845028, -0.042356036603450775, 0.03377627208828926, + 0.014071502722799778, -0.0018434594385325909, -0.0006053714314475656, -0.0038432874716818333, + -0.007012223359197378, -0.0007034881855361164, 0.007561248727142811, -0.042776428163051605, + 0.009373181499540806, 0.0031296780798584223, 0.008734943345189095], + "bias3" : [-0.8870829939842224, -1.152485728263855, -1.3024290800094604, -1.1338839530944824, -0.12526285648345947, + -0.35318782925605774, -0.9353211522102356, -1.0099754333496094], + "weight1" : [ + [-0.012771551497280598, 0.8522672653198242, 0.005031325854361057, 0.010413140058517456, 0.5228086709976196, + 0.0038109351880848408], + [0.01777890883386135, 0.011899493634700775, -0.8380187153816223, 0.010888529941439629, 0.011424682103097439, + -0.5449933409690857], + [0.01807912439107895, 0.012226282618939877, -0.8379306793212891, 0.011077952571213245, 0.011494635604321957, + -0.5451062917709351], + [0.01277169119566679, -0.8522672057151794, -0.005030765198171139, -0.010413102805614471, -0.5228087902069092, + -0.0038107011932879686], + [-0.8513970971107483, -0.01639718748629093, -0.015279111452400684, -0.5237646102905273, -0.014861795119941235, + -0.008371188305318356], + [0.012770231813192368, -0.8522675633430481, -0.005036665592342615, -0.010414107702672482, -0.5228080749511719, + -0.0038130702450871468], + [-0.012771555222570896, 0.8522672653198242, 0.005031300242990255, 0.01041315495967865, 0.5228086709976196, + 0.0038109151646494865], + [-0.8513400554656982, -0.01752384752035141, -0.01597718894481659, -0.5237532258033752, -0.01596754789352417, + -0.009223726578056812], + [-0.012771563604474068, 0.8522672057151794, 0.005031331907957792, 0.010413155891001225, 0.5228087902069092, + 0.0038109072484076023], + [0.8513972163200378, 0.016397155821323395, 0.015279067680239677, 0.5237646698951721, 0.014861810952425003, + 0.008371211588382721], + [0.012771359644830227, -0.8522672653198242, -0.005032065790146589, -0.010413287207484245, -0.5228086709976196, + -0.003811270697042346], + [-0.01777891255915165, -0.011899505741894245, 0.8380184173583984, -0.010888525284826756, -0.011424724943935871, + 0.5449937582015991], + [-0.8513973355293274, -0.016397180035710335, -0.01527908630669117, -0.5237643718719482, -0.014861813746392727, + -0.008371248841285706], + [-0.8513972163200378, -0.01639718748629093, -0.0152790741994977, -0.5237645506858826, -0.01486178208142519, + -0.008371269330382347], + [0.01777891255915165, 0.011899520643055439, -0.8380185961723328, 0.010888495482504368, 0.011424724012613297, + -0.54499351978302], + [-0.012771572917699814, 0.8522672653198242, 0.005031317938119173, 0.010413181968033314, 0.5228086709976196, + 0.003810951951891184], + [-0.017778869718313217, -0.011899495497345924, 0.838018536567688, -0.010888428427278996, -0.011424727737903595, + 0.54499351978302], + [-0.012771587818861008, 0.8522674441337585, 0.005031261593103409, 0.010413181968033314, 0.5228084325790405, + 0.0038109596353024244], + [0.012771571055054665, -0.8522672057151794, -0.00503126997500658, -0.010413173586130142, -0.5228086113929749, + -0.0038109286688268185], + [-0.017778851091861725, -0.011899485252797604, 0.838018536567688, -0.010888493619859219, -0.0114247165620327, + 0.54499351978302]], + "weight2" : [ + [0.004485692363232374, 0.4092908799648285, 0.40233033895492554, -0.0044856867752969265, -0.008494298905134201, + -0.004485597368329763, 0.004485687240958214, -0.008073708973824978, 0.004485689103603363, 0.008494298905134201, + -0.004485669545829296, -0.4092909097671509, -0.00849429052323103, -0.008494291454553604, 0.4092908799648285, + 0.004485682118684053, -0.4092908799648285, 0.004485685378313065, -0.004485683515667915, -0.4092908799648285], + [-0.0012539782328531146, 0.011227603070437908, 0.010983745567500591, 0.0012537383008748293, + -0.44823628664016724, 0.0012566098012030125, -0.0012539738090708852, -0.4422348737716675, + -0.0012539689196273685, 0.44823625683784485, 0.0012543690390884876, -0.011227612383663654, -0.4482363164424896, + -0.4482361674308777, 0.011227606795728207, -0.001253972644917667, -0.011227593757212162, -0.001253975322470069, + 0.001253973226994276, -0.01122759748250246], + [0.0007343984907492995, -0.011749573983252048, -0.011509227566421032, -0.0007341671735048294, + 0.44823724031448364, -0.0007370298844762146, 0.0007343830075114965, 0.4421602189540863, 0.0007343870238400996, + -0.44823718070983887, -0.0007348008803091943, 0.011749575845897198, 0.44823724031448364, 0.44823721051216125, + -0.011749569326639175, 0.0007343803299590945, 0.011749548837542534, 0.0007343843462876976, + -0.0007343952311202884, 0.011749545112252235], + [0.0032850054558366537, 0.012248189188539982, 0.012028064578771591, -0.003285229206085205, -0.4482342302799225, + -0.0032824152149260044, 0.003285015933215618, -0.44198647141456604, 0.0032850138377398252, 0.44823431968688965, + -0.0032846173271536827, -0.012248193845152855, -0.4482342004776001, -0.4482342600822449, 0.01224818080663681, + 0.0032850129064172506, -0.012248178012669086, 0.003285012673586607, -0.003285010578110814, + -0.012248177081346512], + [-0.33326631784439087, 0.001286709913983941, 0.0012743606930598617, 0.33326631784439087, 0.008707520551979542, + 0.3332684338092804, -0.33326634764671326, 0.009328183718025684, -0.33326634764671326, -0.008707527071237564, + 0.33326682448387146, -0.001286706537939608, 0.008707529865205288, 0.008707528933882713, 0.0012867064215242863, + -0.33326634764671326, -0.001286699203774333, -0.3332662582397461, 0.3332662880420685, -0.0012867129407823086], + [-0.005634149070829153, -0.40926215052604675, -0.4023561179637909, 0.005634183995425701, 0.008391384966671467, + 0.005634055007249117, -0.005634150467813015, 0.007969524711370468, -0.005634148605167866, + -0.008391385897994041, 0.005634150002151728, 0.4092622399330139, 0.008391381241381168, 0.008391374722123146, + -0.40926244854927063, -0.0056341588497161865, 0.40926244854927063, -0.0056341467425227165, + 0.005634146276861429, 0.409262478351593], + [-0.006210127845406532, -0.4092416763305664, -0.4023682475090027, 0.006210171617567539, 0.00855502113699913, + 0.006210030987858772, -0.006210129242390394, 0.008134675212204456, -0.006210132036358118, + -0.008555027656257153, 0.006210120394825935, 0.4092416763305664, 0.008555008098483086, 0.008555013686418533, + -0.4092416763305664, -0.006210125517100096, 0.409241646528244, -0.00621012831106782, 0.006210136227309704, + 0.4092416763305664], + [0.3332020938396454, 0.0034835836850106716, 0.0034986836835741997, -0.3332015573978424, -0.011808326467871666, + -0.3332063853740692, 0.3332020342350006, -0.012397287413477898, 0.3332020342350006, 0.01180832739919424, + -0.3332027792930603, -0.0034835890401154757, -0.011808333918452263, -0.01180832739919424, + 0.0034835883416235447, 0.3332020342350006, -0.0034835846163332462, 0.3332020342350006, -0.333202064037323, + -0.003483586013317108], + [0.3332599997520447, -0.00046248571015894413, -0.0004488642734941095, -0.3332599699497223, + -0.009213998913764954, -0.3332628607749939, 0.3332599997520447, -0.009829509072005749, 0.3332599997520447, + 0.009213999845087528, -0.33326059579849243, 0.00046249059960246086, -0.009214004501700401, + -0.009213997051119804, -0.0004624773282557726, 0.3332599997520447, 0.0004624743014574051, 0.3332599997520447, + -0.3332599997520447, 0.00046248442959040403], + [0.005426416639238596, 0.012804088182747364, 0.012594926171004772, -0.005426647607237101, -0.44819632172584534, + -0.005423834081739187, 0.005426429677754641, -0.44185635447502136, 0.005426422227174044, 0.4481962323188782, + -0.005426030606031418, -0.012804095633327961, -0.4481962323188782, -0.4481961727142334, 0.01280409935861826, + 0.005426430609077215, -0.012804084457457066, 0.00542643154039979, -0.005426420830190182, + -0.012804084457457066], + [-8.470881584798917e-05, -0.01166805811226368, -0.011431642808020115, 8.493004861520603e-05, + 0.44824597239494324, 8.206536585930735e-05, -8.470812463201582e-05, 0.4421427845954895, -8.470044849673286e-05, + -0.4482460021972656, 8.428884029854089e-05, 0.011668059974908829, 0.4482460021972656, 0.44824597239494324, + -0.011668050661683083, -8.470762259094045e-05, 0.011668048799037933, -8.469591557513922e-05, + 8.470559259876609e-05, 0.011668049730360508], + [0.0007104332908056676, -0.012103226035833359, -0.011863806284964085, -0.0007102005183696747, + 0.4482336640357971, -0.0007130496669560671, 0.0007104334654286504, 0.4421180486679077, 0.0007104235119186342, + -0.44823363423347473, -0.0007108247373253107, 0.012103239074349403, 0.4482337534427643, 0.44823357462882996, + -0.012103230692446232, 0.0007104267715476453, 0.012103220447897911, 0.0007104279939085245, + -0.0007104240357875824, 0.012103220447897911], + [-0.0008108518086373806, -0.011628665030002594, -0.011395211331546307, 0.0008111011702567339, + 0.4482516348361969, 0.0008082437561824918, -0.0008108714246191084, 0.44211941957473755, -0.000810873054433614, + -0.4482516050338745, 0.0008104708977043629, 0.011628672480583191, 0.4482516050338745, 0.4482516348361969, + -0.011628661304712296, -0.0008108695619739592, 0.011628646403551102, -0.0008108615875244141, + 0.0008108713664114475, 0.011628646403551102], + [0.0052762399427592754, 0.4092657268047333, 0.4023624062538147, -0.005276260897517204, -0.008586729876697063, + -0.0052761598490178585, 0.005276238080114126, -0.008165040984749794, 0.005276246462017298, + 0.008586726151406765, -0.0052762338891625404, -0.4092656970024109, -0.00858672522008419, -0.008586717769503593, + 0.4092656970024109, 0.005276253912597895, -0.40926575660705566, 0.005276249721646309, -0.005276253912597895, + -0.4092658758163452], + [0.0004666325112339109, -0.01200425997376442, -0.011765711009502411, -0.00046640209620818496, + 0.44823700189590454, -0.00046925980132073164, 0.00046662817476317286, 0.4421238601207733, + 0.0004666206077672541, -0.448236882686615, -0.00046703135012649, 0.012004264630377293, 0.44823694229125977, + 0.4482368230819702, -0.01200425811111927, 0.0004666294262278825, 0.012004253454506397, 0.0004666416789405048, + -0.00046663961256854236, 0.012004251591861248], + [-0.006058018188923597, -0.40924715995788574, -0.4023823142051697, 0.006058032624423504, 0.008350761607289314, + 0.006057909224182367, -0.006058020517230034, 0.007928609848022461, -0.00605802284553647, -0.00835077092051506, + 0.006058006081730127, 0.40924718976020813, 0.00835077092051506, 0.00835077092051506, -0.40924715995788574, + -0.006058022379875183, 0.4092472195625305, -0.006058013532310724, 0.0060580214485526085, 0.40924715995788574], + [0.3332037329673767, 0.0033146513160318136, 0.003330634441226721, -0.33320334553718567, -0.011781765148043633, + -0.3332081437110901, 0.33320361375808716, -0.012370237149298191, 0.3332037031650543, 0.011781767010688782, + -0.33320438861846924, -0.0033146499190479517, -0.011781766079366207, -0.011781767010688782, + 0.003314657835289836, 0.3332037329673767, -0.0033146559726446867, 0.33320367336273193, -0.3332037329673767, + -0.0033146515488624573], + [-0.005700466223061085, -0.4092574417591095, -0.40236255526542664, 0.005700497422367334, 0.008488606661558151, + 0.005700360052287579, -0.005700466688722372, 0.008067479357123375, -0.00570047739893198, -0.008488614112138748, + 0.005700456909835339, 0.4092574417591095, 0.008488607592880726, 0.008488602004945278, -0.4092575013637543, + -0.005700466223061085, 0.4092575013637543, -0.005700462963432074, 0.005700469017028809, 0.4092574715614319], + [0.0007849707617424428, 0.012151014059782028, 0.011918344534933567, -0.0007852140697650611, -0.4482424557209015, + -0.000782380870077759, 0.0007849783287383616, -0.4420732855796814, 0.0007849846151657403, 0.4482424259185791, + -0.0007845927611924708, -0.012151009403169155, -0.4482423961162567, -0.4482423961162567, 0.012150995433330536, + 0.0007849778048694134, -0.012150992639362812, 0.0007849839166738093, -0.0007849822868593037, + -0.012150988914072514], + [-0.005965784657746553, -0.40925177931785583, -0.4023708701133728, 0.005965803749859333, 0.008349047973752022, + 0.005965673830360174, -0.005965787451714277, 0.0079267006367445, -0.005965782329440117, -0.008349056355655193, + 0.0059657711535692215, 0.4092518091201782, 0.008349052630364895, 0.008349040523171425, -0.4092518091201782, + -0.00596578698605299, 0.40925195813179016, -0.00596578698605299, 0.005965790245682001, 0.4092518985271454]], + "weight3" : [ + [-335.4917907714844, 295.4498596191406, -300.9907531738281, 292.68359375, 439.89520263671875, 340.9864807128906, + 341.3768615722656, -446.5185852050781, -442.7583312988281, 283.7157287597656, -300.0808410644531, + -300.40203857421875, -298.2679748535156, -338.4247741699219, -298.95599365234375, 342.16082763671875, + -442.63568115234375, 340.7982482910156, 291.43646240234375, 340.5745544433594], + [333.4602355957031, 283.6884460449219, -290.77532958984375, 284.6348571777344, 458.39697265625, + -331.57843017578125, -330.1046142578125, -458.7938232421875, -459.3420104980469, 277.3214416503906, + -290.0921936035156, -291.56329345703125, -287.3941955566406, 332.5392150878906, -289.87408447265625, + -327.3385009765625, -456.78961181640625, -331.62420654296875, 282.5384521484375, -330.56890869140625], + [-337.1375427246094, 289.10272216796875, -294.45098876953125, 296.2870788574219, -438.8525085449219, + 339.5340881347656, 337.7414855957031, 435.8607177734375, 437.567138671875, 293.1355895996094, + -295.7801208496094, -293.8027038574219, -295.40887451171875, -337.87371826171875, -293.3611755371094, + 339.5439147949219, 437.2536315917969, 338.9185791015625, 289.39752197265625, 338.4055480957031], + [331.2817687988281, 278.56243896484375, -286.8565979003906, 290.38287353515625, -383.39654541015625, + -332.671630859375, -332.4136047363281, 458.2372741699219, 394.3974609375, 288.79791259765625, + -287.97247314453125, -287.65728759765625, -287.8821105957031, 333.19647216796875, -286.55657958984375, + -330.1717834472656, 458.9079895019531, -332.5555114746094, 282.8378601074219, -332.8061828613281], + [-327.96527099609375, -282.6752014160156, 281.63214111328125, -289.4588928222656, 376.37738037109375, + 333.9535217285156, 335.0184020996094, -457.1034851074219, -390.4696960449219, -294.4223327636719, + 281.4463806152344, 281.6940002441406, 284.7673645019531, -332.4997253417969, 284.1268615722656, + 335.67401123046875, -454.1480407714844, 334.2760009765625, -288.58868408203125, 333.94891357421875], + [341.65960693359375, -295.7402648925781, 292.73046875, -298.5085754394531, 437.5282287597656, + -340.9335021972656, -338.32208251953125, -443.3055114746094, -439.6506042480469, -301.7493591308594, + 292.2878723144531, 291.4482116699219, 295.3114013671875, 341.15484619140625, 294.27960205078125, + -337.8680419921875, -440.09320068359375, -340.21820068359375, -298.2076721191406, -340.53692626953125], + [-330.5609130859375, -291.5669860839844, 287.8730773925781, -286.07989501953125, -456.10418701171875, + 332.3810119628906, 332.67425537109375, 451.7678527832031, 455.01104736328125, -287.4740905761719, + 285.10052490234375, 287.6368103027344, 286.7076721191406, -332.6780090332031, 289.9099426269531, + 333.3375549316406, 453.9902648925781, 332.9720764160156, -292.4762878417969, 331.2679748535156], + [340.7230224609375, -301.74951171875, 296.9709777832031, -292.6571960449219, -440.14434814453125, + -343.39404296875, -343.3323059082031, 441.12506103515625, 440.0010070800781, -290.7984924316406, + 294.84613037109375, 295.7033996582031, 296.45263671875, 342.21826171875, 298.08880615234375, + -341.65997314453125, 440.3846130371094, -343.41845703125, -299.17828369140625, -344.0526428222656]] +} \ No newline at end of file diff --git a/dryvr_plus_plus/example/example_agent/quadrotor_agent.py b/dryvr_plus_plus/example/example_agent/quadrotor_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..cc5314f10ab5d02ad9e5df597a72270421fed624 --- /dev/null +++ b/dryvr_plus_plus/example/example_agent/quadrotor_agent.py @@ -0,0 +1,207 @@ +# Example agent. +from typing import Tuple, List +import json +import os +import numpy as np +from scipy.integrate import ode +import torch +import math +from dryvr_plus_plus.agents import BaseAgent +from dryvr_plus_plus.map import LaneMap + +class FFNNC(torch.nn.Module): + def __init__(self, D_in=6, D_out=8): + super(FFNNC, self).__init__() + self.layer1 = torch.nn.Linear(D_in, 20) + self.layer2 = torch.nn.Linear(20, 20) + self.layer3 = torch.nn.Linear(20, D_out) + + def forward(self, x): + x = torch.tanh(self.layer1(x)) + x = torch.tanh(self.layer2(x)) + x = self.layer3(x) + return x + + +class QuadrotorAgent(BaseAgent): + def __init__(self, id, code=None, file_name=None, waypoints=[], boxes=[], time_limits=[]): + super().__init__(id, code, file_name) + self.waypoints = waypoints + self.boxes = boxes + self.time_limits = time_limits + + def in_box(self, state, waypoint_id): + if waypoint_id >= len(self.boxes): + return False + box = self.boxes[int(waypoint_id)] + for i in range(len(box[0])): + if state[i] < box[0][i] or state[i] > box[1][i]: + return False + return True + + @staticmethod + def dynamic(t, state, u): + u1, u2, u3, bx, by, bz, sc, ddf = u # len 7 + vx, vy, vz, waypoint, done_flag = state[6:] # len 11 + sc = -1 * sc + dvx = 9.81 * np.sin(u1) / np.cos(u1) + dvy = -9.81 * np.sin(u2) / np.cos(u2) + tmp1 = dvx * math.cos(sc) - dvy * math.sin(sc) + tmp2 = dvx * math.sin(sc) + dvy * math.cos(sc) + dvx = tmp1 + dvy = tmp2 + dvz = u3 - 9.81 + dx = vx + dy = vy + dz = vz + dref_x = bx + dref_y = by + dref_z = bz + dwaypoint = 0 + ddone_flag = ddf + return [dref_x, dref_y, dref_z, dx, dy, dz, dvx, dvy, dvz, dwaypoint, ddone_flag] + + def action_handler(self, state) -> Tuple[float, float]: + waypoint = state[-2] + df = 0 + if self.in_box(state[3:9], waypoint): + df = 1 + return df + + def runModel(self, initalCondition, time_bound, time_step, ref_input): + path = os.path.abspath(__file__) + path = path.replace('quadrotor_agent.py', 'prarm.json') + # print(path) + with open(path, 'r') as f: + prarms = json.load(f) + bias1 = prarms['bias1'] + bias2 = prarms['bias2'] + bias3 = prarms['bias3'] + weight1 = prarms['weight1'] + weight2 = prarms['weight2'] + weight3 = prarms['weight3'] + + bias1 = torch.FloatTensor(bias1) + bias2 = torch.FloatTensor(bias2) + bias3 = torch.FloatTensor(bias3) + weight1 = torch.FloatTensor(weight1) + weight2 = torch.FloatTensor(weight2) + weight3 = torch.FloatTensor(weight3) + controller = FFNNC() + controller.layer1.weight = torch.nn.Parameter(weight1) + controller.layer2.weight = torch.nn.Parameter(weight2) + controller.layer3.weight = torch.nn.Parameter(weight3) + controller.layer1.bias = torch.nn.Parameter(bias1) + controller.layer2.bias = torch.nn.Parameter(bias2) + controller.layer3.bias = torch.nn.Parameter(bias3) + control_input_list = [[-0.1, -0.1, 7.81], + [-0.1, -0.1, 11.81], + [-0.1, 0.1, 7.81], + [-0.1, 0.1, 11.81], + [0.1, -0.1, 7.81], + [0.1, -0.1, 11.81], + [0.1, 0.1, 7.81], + [0.1, 0.1, 11.81]] + init = initalCondition + trajectory = [init] + r = ode(self.dynamic) + # r.set_initial_value(init) + ex_list = [] + ey_list = [] + ez_list = [] + t = 0 + time = [t] + trace = [[t]] + trace[0].extend(init[3:]) + i = 0 + while t <= time_bound: + ex = trajectory[i][3] - trajectory[i][0] + ey = trajectory[i][4] - trajectory[i][1] + ez = trajectory[i][5] - trajectory[i][2] + evx = trajectory[i][6] - ref_input[0] + evy = trajectory[i][7] - ref_input[1] + evz = trajectory[i][8] - ref_input[2] + + sc = ref_input[3] # math.atan2(dot, det) + + tmp1 = ex * math.cos(sc) - ey * math.sin(sc) + tmp2 = ex * math.sin(sc) + ey * math.cos(sc) + ex = tmp1 + ey = tmp2 + + tmp1 = evx * math.cos(sc) - evy * math.sin(sc) + tmp2 = evx * math.sin(sc) + evy * math.cos(sc) + evx = tmp1 + evy = tmp2 + + data = torch.FloatTensor( + [0.2 * ex, 0.2 * ey, 0.2 * ez, 0.1 * evx, 0.1 * evy, 0.1 * evz]) + res = controller(data) + res = res.detach().numpy() + idx = np.argmax(res) + u = control_input_list[idx] + ref_input[0:3] + [sc] + + df = self.action_handler(init) + u = u+[df] + init = trajectory[i] # len 11 + r = ode(self.dynamic) + r.set_initial_value(init) + r.set_f_params(u) + val = r.integrate(r.t + time_step) + + t += time_step + i += 1 + # print(i,idx,u,res) + trajectory.append(val) + time.append(t) + + ex_list.append(ex) + ey_list.append(ey) + ez_list.append(ez) + trace.append([t]) + # remove the reference trajectory from the trace + trace[i].extend(val[3:]) + return trace + + def TC_simulate(self, mode: List[str], initialCondition, time_bound, time_step, lane_map: LaneMap = None) -> np.ndarray: + # total time_bound remained + time_bound = float(time_bound) + initialCondition[-2] = int(initialCondition[-2]) + time_bound = min(self.time_limits[initialCondition[-2]], time_bound) + number_points = int(np.ceil(time_bound/time_step)) + t = [round(i*time_step, 10) for i in range(0, number_points)] + # todo + # if initialCondition[-2] == 2: + # print('r') + # if mode[0] != 'Follow_Waypoint': + # raise ValueError() + mode_parameters = self.waypoints[initialCondition[-2]] + # print(initialCondition[-2], mode_parameters) + # init = initialCondition + # trace = [[0]+init] + # for i in range(len(t)): + # steering, a = self.action_handler(mode, init, lane_map) + # r = ode(self.dynamic) + # r.set_initial_value(init).set_f_params([steering, a]) + # res: np.ndarray = r.integrate(r.t + time_step) + # init = res.flatten().tolist() + # if init[3] < 0: + # init[3] = 0 + # trace.append([t[i] + time_step] + init) + ref_vx = (mode_parameters[3] - mode_parameters[0]) / time_bound + ref_vy = (mode_parameters[4] - mode_parameters[1]) / time_bound + ref_vz = (mode_parameters[5] - mode_parameters[2]) / time_bound + sym_rot_angle = 0 + trace = self.runModel(mode_parameters[0:3] + list(initialCondition), time_bound, time_step, [ref_vx, ref_vy, ref_vz, + sym_rot_angle]) + return np.array(trace) + +# import json +# import os +# if __name__ == "__main__": +# path = os.path.abspath(__file__) +# path=path.replace('tempCodeRunnerFile.py', 'prarm.json') +# print(path) +# with open(path, 'r') as f: +# prarms = json.load(f) +# print(prarms) diff --git a/dryvr_plus_plus/example/example_map/simple_map2.py b/dryvr_plus_plus/example/example_map/simple_map2.py index 49af38b66ed4f1780158e30c334a5d5b5b5696ab..22cf92da950035e300a08ec2ef176cf63a84694c 100644 --- a/dryvr_plus_plus/example/example_map/simple_map2.py +++ b/dryvr_plus_plus/example/example_map/simple_map2.py @@ -2,20 +2,18 @@ from dryvr_plus_plus import LaneMap, LaneSegment, StraightLane, CircularLane, La import numpy as np - class SimpleMap2(LaneMap): def __init__(self): super().__init__() segment0 = StraightLane( 'seg0', - [0, 0], - [100, 0], + [0,0], + [100,0], 3 ) lane0 = Lane('Lane1', [segment0]) self.add_lanes([lane0]) - class SimpleMap3(LaneMap): def __init__(self): super().__init__() @@ -48,41 +46,6 @@ class SimpleMap3(LaneMap): self.right_lane_dict[lane0.id].append(lane1.id) self.right_lane_dict[lane1.id].append(lane2.id) -class SimpleMap3_v2(LaneMap): - def __init__(self): - super().__init__() - segment0 = StraightLane( - 'Seg0', - [0, 3], - [50, 3], - 3, - speed_limit=[(0, 1), (10, 2)] - ) - lane0 = Lane('Lane0', [segment0], speed_limit=2) - segment1 = StraightLane( - 'seg0', - [0, 0], - [50, 0], - 3, - speed_limit=[(0, 1), (20, 3)] - ) - lane1 = Lane('Lane1', [segment1], speed_limit=1) - segment2 = StraightLane( - 'seg0', - [0, -3], - [50, -3], - 3, - speed_limit=[(0, 1), (25, 2.5)] - ) - lane2 = Lane('Lane2', [segment2], speed_limit=3) - # segment2 = LaneSegment('Lane1', 3) - # self.add_lanes([segment1,segment2]) - self.add_lanes([lane0, lane1, lane2]) - self.left_lane_dict[lane1.id].append(lane0.id) - self.left_lane_dict[lane2.id].append(lane1.id) - self.right_lane_dict[lane0.id].append(lane1.id) - self.right_lane_dict[lane1.id].append(lane2.id) - class SimpleMap4(LaneMap): def __init__(self): super().__init__() @@ -139,58 +102,58 @@ class SimpleMap5(LaneMap): super().__init__() segment0 = StraightLane( 'Seg0', - [0, 3], - [15, 3], + [0,3], + [15,3], 3 ) segment1 = StraightLane( 'Seg1', - [15, 3], - [25, 13], + [15,3], + [25,13], 3 ) segment2 = StraightLane( 'Seg2', - [25, 13], - [50, 13], + [25,13], + [50,13], 3 ) lane0 = Lane('Lane0', [segment0, segment1, segment2]) segment0 = StraightLane( 'seg0', - [0, 0], - [17, 0], + [0,0], + [17,0], 3 ) segment1 = StraightLane( 'seg1', - [17, 0], - [27, 10], + [17,0], + [27,10], 3 ) segment2 = StraightLane( 'seg2', - [27, 10], - [50, 10], + [27,10], + [50,10], 3 ) lane1 = Lane('Lane1', [segment0, segment1, segment2]) segment0 = StraightLane( 'seg0', - [0, -3], - [19, -3], + [0,-3], + [19,-3], 3 ) segment1 = StraightLane( 'seg1', - [19, -3], - [29, 7], + [19,-3], + [29,7], 3 ) segment2 = StraightLane( 'seg2', - [29, 7], - [50, 7], + [29,7], + [50,7], 3 ) lane2 = Lane('Lane2', [segment0, segment1, segment2]) @@ -200,19 +163,18 @@ class SimpleMap5(LaneMap): self.right_lane_dict[lane0.id].append(lane1.id) self.right_lane_dict[lane1.id].append(lane2.id) - class SimpleMap6(LaneMap): def __init__(self): super().__init__() segment0 = StraightLane( 'Seg0', - [0, 3], - [15, 3], + [0,3], + [15,3], 3 ) segment1 = CircularLane( 'Seg1', - [15, 8], + [15,8], 5, np.pi*3/2, np.pi*2, @@ -221,20 +183,20 @@ class SimpleMap6(LaneMap): ) segment2 = StraightLane( 'Seg2', - [20, 8], - [20, 30], + [20,8], + [20,30], 3 ) lane0 = Lane('Lane0', [segment0, segment1, segment2]) segment0 = StraightLane( 'seg0', - [0, 0], - [18, 0], + [0,0], + [18,0], 3 ) segment1 = CircularLane( 'seg1', - [18, 5], + [18,5], 5, 3*np.pi/2, 2*np.pi, @@ -243,20 +205,20 @@ class SimpleMap6(LaneMap): ) segment2 = StraightLane( 'seg2', - [23, 5], - [23, 30], + [23,5], + [23,30], 3 ) lane1 = Lane('Lane1', [segment0, segment1, segment2]) segment0 = StraightLane( 'seg0', - [0, -3], - [21, -3], + [0,-3], + [21,-3], 3 ) segment1 = CircularLane( 'seg1', - [21, 2], + [21,2], 5, np.pi*3/2, np.pi*2, @@ -265,8 +227,8 @@ class SimpleMap6(LaneMap): ) segment2 = StraightLane( 'seg2', - [26, 2], - [26, 30], + [26,2], + [26,30], 3 ) lane2 = Lane('Lane2', [segment0, segment1, segment2]) @@ -276,7 +238,6 @@ class SimpleMap6(LaneMap): self.right_lane_dict[lane0.id].append(lane1.id) self.right_lane_dict[lane1.id].append(lane2.id) - if __name__ == "__main__": test_map = SimpleMap3() print(test_map.left_lane_dict) diff --git a/dryvr_plus_plus/example/example_sensor/craft_sensor.py b/dryvr_plus_plus/example/example_sensor/craft_sensor.py new file mode 100644 index 0000000000000000000000000000000000000000..e9d88195997ad9eadab45dca25851e8b23e925a6 --- /dev/null +++ b/dryvr_plus_plus/example/example_sensor/craft_sensor.py @@ -0,0 +1,68 @@ +import numpy as np + + +def sets(d, thing, attrs, vals): + d.update({thing + "." + k: v for k, v in zip(attrs, vals)}) + + +def adds(d, thing, attrs, vals): + for k, v in zip(attrs, vals): + if thing + '.' + k not in d: + d[thing + '.' + k] = [v] + else: + d[thing + '.' + k].append(v) + + +def set_states_2d_ball(cnts, disc, thing, val): + state, mode, static = val + sets(cnts, thing, ['xp', 'yp', 'xd', 'yd', + 'total_time', 'cycle_time'], state[1:7]) + sets(disc, thing, ["craft_mode"], mode) + + +def set_states_3d_ball(cnts, disc, thing, val): + state, mode, static = val + transp = np.transpose(np.array(state)[:, 1:7]) + assert len(transp) == 6 + sets(cnts, thing, ['xp', 'yp', 'xd', 'yd', + 'total_time', 'cycle_time'], transp) + sets(disc, thing, ["craft_mode"], mode) + + +def add_states_2d_ball(cont, disc, thing, val): + state, mode, static = val + adds(cont, thing, ['xp', 'yp', 'xd', 'yd', + 'total_time', 'cycle_time'], state[1:7]) + adds(disc, thing, ["craft_mode"], mode) + + +def add_states_3d_ball(cont, disc, thing, val): + state, mode, static = val + transp = np.transpose(np.array(state)[:, 1:7]) + assert len(transp) == 6 + adds(cont, thing, ['xp', 'yp', 'xd', 'yd', + 'total_time', 'cycle_time'], transp) + adds(disc, thing, ["craft_mode"], mode) + + +class CraftSensor(): + def sense(self, scenario, agent, state_dict, lane_map): + cont = {} + disc = {} + len_dict = {'others': len(state_dict)-1} + tmp = np.array(list(state_dict.values())[0][0]) + if tmp.ndim < 2: + for agent_id in state_dict: + if agent_id == agent.id: + set_states_2d_ball(cont, disc, 'ego', state_dict[agent_id]) + else: + add_states_2d_ball(cont, disc, 'others', + state_dict[agent_id]) + else: + for agent_id in state_dict: + if agent_id == agent.id: + set_states_3d_ball(cont, disc, "ego", state_dict[agent_id]) + else: + add_states_3d_ball(cont, disc, 'others', + state_dict[agent_id]) + return cont, disc, len_dict diff --git a/dryvr_plus_plus/example/example_sensor/fake_sensor.py b/dryvr_plus_plus/example/example_sensor/fake_sensor.py index 1e9260c46b3e03cf004e7e2d636f395114abea90..0d3be16596e8b81de2516769772dadb938a2fe70 100644 --- a/dryvr_plus_plus/example/example_sensor/fake_sensor.py +++ b/dryvr_plus_plus/example/example_sensor/fake_sensor.py @@ -1,5 +1,6 @@ import numpy as np + class FakeSensor1: def sense(self, scenario, agent, state_dict, lane_map): cnts = {} @@ -14,22 +15,26 @@ class FakeSensor1: disc['ego.lane_mode'] = mode[1] return cnts, disc, {} + def sets(d, thing, attrs, vals): d.update({thing + "." + k: v for k, v in zip(attrs, vals)}) + def adds(d, thing, attrs, vals): for k, v in zip(attrs, vals): if thing + '.' + k not in d: - d[thing + '.' + k] = [v] + d[thing + '.' + k] = [v] else: - d[thing + '.' + k].append(v) - + d[thing + '.' + k].append(v) + + def add_states_2d(cont, disc, thing, val): - state, mode, static = val - adds(cont, thing, ['x','y','theta','v'], state[1:5]) + state, mode, static = val + adds(cont, thing, ['x', 'y', 'theta', 'v'], state[1:5]) adds(disc, thing, ["vehicle_mode", "lane_mode"], mode) adds(disc, thing, ['type'], static) + def add_states_3d(cont, disc, thing, val): state, mode, static = val transp = np.transpose(np.array(state)[:, 1:5]) @@ -38,12 +43,14 @@ def add_states_3d(cont, disc, thing, val): adds(disc, thing, ["vehicle_mode", "lane_mode"], mode) adds(disc, thing, ['type'], static) + def set_states_2d(cnts, disc, thing, val): state, mode, static = val sets(cnts, thing, ["x", "y", "theta", "v"], state[1:5]) sets(disc, thing, ["vehicle_mode", "lane_mode"], mode) sets(disc, thing, ['type'], static) + def set_states_3d(cnts, disc, thing, val): state, mode, static = val transp = np.transpose(np.array(state)[:, 1:5]) @@ -52,6 +59,7 @@ def set_states_3d(cnts, disc, thing, val): sets(disc, thing, ["vehicle_mode", "lane_mode"], mode) sets(disc, thing, ['type'], static) + class FakeSensor2: def sense(self, scenario, agent, state_dict, lane_map): cnts = {} @@ -92,18 +100,19 @@ class FakeSensor2: set_states_3d(cnts, disc, "sign", state_dict["sign"]) return cnts, disc, {} + class FakeSensor3: def sense(self, scenario, agent, state_dict, lane_map): cont = {} disc = {} - len_dict = {'others':len(state_dict)-1} + len_dict = {'others': len(state_dict)-1} tmp = np.array(state_dict['car1'][0]) if tmp.ndim < 2: for agent_id in state_dict: if agent_id == agent.id: set_states_2d(cont, disc, 'ego', state_dict[agent_id]) else: - add_states_2d(cont, disc, 'others', state_dict[agent_id]) + add_states_2d(cont, disc, 'others', state_dict[agent_id]) else: for agent_id in state_dict: if agent_id == agent.id: @@ -117,39 +126,48 @@ def set_states_2d_ball(cnts, disc, thing, val): state, mode = val sets(cnts, thing, ["x", "y", "vx", "vy"], state[1:5]) sets(disc, thing, ["ball_mode", "lane_mode"], mode) + + def set_states_3d_ball(cnts, disc, thing, val): state, mode = val transp = np.transpose(np.array(state)[:, 1:5]) assert len(transp) == 4 sets(cnts, thing, ["x", "y", "vx", "vy"], transp) sets(disc, thing, ["ball_mode", "lane_mode"], mode) + + def add_states_2d_ball(cont, disc, thing, val): state, mode = val - adds(cont, thing, ['x','y','vx','vy'], state[1:5]) + adds(cont, thing, ['x', 'y', 'vx', 'vy'], state[1:5]) adds(disc, thing, ["ball_mode", "lane_mode", "type"], mode) + + def add_states_3d_ball(cont, disc, thing, val): state, mode = val transp = np.transpose(np.array(state)[:, 1:5]) assert len(transp) == 4 - adds(cont, thing, ['x','y','vx','vy'], transp) + adds(cont, thing, ['x', 'y', 'vx', 'vy'], transp) adds(disc, thing, ["ball_mode", "lane_mode", "type"], mode) + class FakeSensor4: def sense(self, scenario, agent, state_dict, lane_map): cont = {} disc = {} - len_dict = {'others':len(state_dict)-1} + len_dict = {'others': len(state_dict)-1} tmp = np.array(list(state_dict.values())[0]) if tmp.ndim < 2: for agent_id in state_dict: if agent_id == agent.id: set_states_2d_ball(cont, disc, 'ego', state_dict[agent_id]) else: - add_states_2d_ball(cont, disc, 'others', state_dict[agent_id]) + add_states_2d_ball(cont, disc, 'others', + state_dict[agent_id]) else: for agent_id in state_dict: if agent_id == agent.id: set_states_3d_ball(cont, disc, "ego", state_dict[agent_id]) else: - add_states_3d_ball(cont, disc, 'others', state_dict[agent_id]) + add_states_3d_ball(cont, disc, 'others', + state_dict[agent_id]) return cont, disc, len_dict diff --git a/dryvr_plus_plus/example/example_sensor/quadrotor_sensor.py b/dryvr_plus_plus/example/example_sensor/quadrotor_sensor.py new file mode 100644 index 0000000000000000000000000000000000000000..d4d6c07d3bb97a4e7b873c3605da8f5554bd753f --- /dev/null +++ b/dryvr_plus_plus/example/example_sensor/quadrotor_sensor.py @@ -0,0 +1,68 @@ +import numpy as np + + +def sets(d, thing, attrs, vals): + d.update({thing + "." + k: v for k, v in zip(attrs, vals)}) + + +def adds(d, thing, attrs, vals): + for k, v in zip(attrs, vals): + if thing + '.' + k not in d: + d[thing + '.' + k] = [v] + else: + d[thing + '.' + k].append(v) + + +def set_states_2d_ball(cnts, disc, thing, val): + state, mode, static = val + sets(cnts, thing, ['x', 'y', 'z', 'vx', 'vy', 'vz', + 'waypoint_index', 'done_flag'], state[1:9]) + sets(disc, thing, ["craft_mode"], mode) + + +def set_states_3d_ball(cnts, disc, thing, val): + state, mode, static = val + transp = np.transpose(np.array(state)[:, 1:9]) + assert len(transp) == 8 + sets(cnts, thing, ['x', 'y', 'z', 'vx', 'vy', 'vz', + 'waypoint_index', 'done_flag'], transp) + sets(disc, thing, ["craft_mode"], mode) + + +def add_states_2d_ball(cont, disc, thing, val): + state, mode, static = val + adds(cont, thing, ['x', 'y', 'z', 'vx', 'vy', 'vz', + 'waypoint_index', 'done_flag'], state[1:9]) + adds(disc, thing, ["craft_mode"], mode) + + +def add_states_3d_ball(cont, disc, thing, val): + state, mode, static = val + transp = np.transpose(np.array(state)[:, 1:9]) + assert len(transp) == 8 + adds(cont, thing, ['x', 'y', 'z', 'vx', 'vy', + 'vz', 'waypoint_index', 'done_flag'], transp) + adds(disc, thing, ["craft_mode"], mode) + + +class QuadrotorSensor(): + def sense(self, scenario, agent, state_dict, lane_map): + cont = {} + disc = {} + len_dict = {'others': len(state_dict)-1} + tmp = np.array(list(state_dict.values())[0][0]) + if tmp.ndim < 2: + for agent_id in state_dict: + if agent_id == agent.id: + set_states_2d_ball(cont, disc, 'ego', state_dict[agent_id]) + else: + add_states_2d_ball(cont, disc, 'others', + state_dict[agent_id]) + else: + for agent_id in state_dict: + if agent_id == agent.id: + set_states_3d_ball(cont, disc, "ego", state_dict[agent_id]) + else: + add_states_3d_ball(cont, disc, 'others', + state_dict[agent_id]) + return cont, disc, len_dict diff --git a/dryvr_plus_plus/example/example_sensor/thermo_sensor.py b/dryvr_plus_plus/example/example_sensor/thermo_sensor.py new file mode 100644 index 0000000000000000000000000000000000000000..19a927bcac2a0427a56c2480ac9b95b5af4a14ac --- /dev/null +++ b/dryvr_plus_plus/example/example_sensor/thermo_sensor.py @@ -0,0 +1,64 @@ +import numpy as np + + +def sets(d, thing, attrs, vals): + d.update({thing + "." + k: v for k, v in zip(attrs, vals)}) + + +def adds(d, thing, attrs, vals): + for k, v in zip(attrs, vals): + if thing + '.' + k not in d: + d[thing + '.' + k] = [v] + else: + d[thing + '.' + k].append(v) + + +def set_states_2d_ball(cnts, disc, thing, val): + state, mode, static = val + sets(cnts, thing, ["temp", "total_time", "cycle_time"], state[1:4]) + sets(disc, thing, ["thermo_mode"], mode) + + +def set_states_3d_ball(cnts, disc, thing, val): + state, mode, static = val + transp = np.transpose(np.array(state)[:, 1:4]) + assert len(transp) == 3 + sets(cnts, thing, ["temp", "total_time", "cycle_time"], transp) + sets(disc, thing, ["thermo_mode"], mode) + + +def add_states_2d_ball(cont, disc, thing, val): + state, mode, static = val + adds(cont, thing, ["temp", "total_time", "cycle_time"], state[1:4]) + adds(disc, thing, ["thermo_mode"], mode) + + +def add_states_3d_ball(cont, disc, thing, val): + state, mode, static = val + transp = np.transpose(np.array(state)[:, 1:4]) + assert len(transp) == 3 + adds(cont, thing, ["temp", "total_time", "cycle_time"], transp) + adds(disc, thing, ["thermo_mode"], mode) + + +class ThermoSensor(): + def sense(self, scenario, agent, state_dict, lane_map): + cont = {} + disc = {} + len_dict = {'others': len(state_dict)-1} + tmp = np.array(list(state_dict.values())[0][0]) + if tmp.ndim < 2: + for agent_id in state_dict: + if agent_id == agent.id: + set_states_2d_ball(cont, disc, 'ego', state_dict[agent_id]) + else: + add_states_2d_ball(cont, disc, 'others', + state_dict[agent_id]) + else: + for agent_id in state_dict: + if agent_id == agent.id: + set_states_3d_ball(cont, disc, "ego", state_dict[agent_id]) + else: + add_states_3d_ball(cont, disc, 'others', + state_dict[agent_id]) + return cont, disc, len_dict diff --git a/dryvr_plus_plus/map/lane.py b/dryvr_plus_plus/map/lane.py index d778f29db3b226ca2f29e5c4693769d1398a0191..27d290e5bef6217bad55662d4a4639f3f6c0cfe0 100644 --- a/dryvr_plus_plus/map/lane.py +++ b/dryvr_plus_plus/map/lane.py @@ -4,14 +4,11 @@ import numpy as np from dryvr_plus_plus.map.lane_segment import AbstractLane - class Lane(): COMPENSATE = 3 - - def __init__(self, id, seg_list: List[AbstractLane], speed_limit=None): + def __init__(self, id, seg_list: List[AbstractLane]): self.id = id self.segment_list: List[AbstractLane] = seg_list - self.speed_limit = speed_limit self._set_longitudinal_start() def _set_longitudinal_start(self): @@ -20,37 +17,28 @@ class Lane(): lane_seg.longitudinal_start = longitudinal_start longitudinal_start += lane_seg.length - def get_lane_segment(self, position: np.ndarray) -> AbstractLane: + def get_lane_segment(self, position:np.ndarray) -> AbstractLane: for seg_idx, segment in enumerate(self.segment_list): logitudinal, lateral = segment.local_coordinates(position) - # why COMPENSATE? why no lateral? use on_lane? is_on = 0-Lane.COMPENSATE <= logitudinal < segment.length if is_on: return seg_idx, segment - return -1, None + return -1,None - def get_heading(self, position: np.ndarray) -> float: + def get_heading(self, position:np.ndarray) -> float: seg_idx, segment = self.get_lane_segment(position) longitudinal, lateral = segment.local_coordinates(position) heading = segment.heading_at(longitudinal) return heading - def get_longitudinal_position(self, position: np.ndarray) -> float: + def get_longitudinal_position(self, position:np.ndarray) -> float: seg_idx, segment = self.get_lane_segment(position) longitudinal, lateral = segment.local_coordinates(position) - for i in range(seg_idx): + for i in range(seg_idx): longitudinal += self.segment_list[i].length return longitudinal - def get_lateral_distance(self, position: np.ndarray) -> float: + def get_lateral_distance(self, position:np.ndarray) -> float: seg_idx, segment = self.get_lane_segment(position) longitudinal, lateral = segment.local_coordinates(position) return lateral - - def get_speed_limit_old(self, position: np.ndarray) -> float: - seg_idx, segment = self.get_lane_segment(position) - longitudinal, lateral = segment.local_coordinates(position) - return segment.speed_limit_at(longitudinal) - - def get_speed_limit(self): - return self.speed_limit diff --git a/dryvr_plus_plus/map/lane_map.py b/dryvr_plus_plus/map/lane_map.py index c4703d467cd2a690a77e4afd50fd4b690e7c6789..370d2fb72e3d1e92e0df20beed29aea42714fdd4 100644 --- a/dryvr_plus_plus/map/lane_map.py +++ b/dryvr_plus_plus/map/lane_map.py @@ -7,19 +7,17 @@ import numpy as np from dryvr_plus_plus.map.lane_segment import AbstractLane from dryvr_plus_plus.map.lane import Lane - class LaneMap: - def __init__(self, lane_seg_list: List[Lane] = []): - self.lane_dict: Dict[str, Lane] = {} - self.left_lane_dict: Dict[str, List[str]] = {} - self.right_lane_dict: Dict[str, List[str]] = {} + def __init__(self, lane_seg_list:List[Lane] = []): + self.lane_dict:Dict[str, Lane] = {} + self.left_lane_dict:Dict[str, List[str]] = {} + self.right_lane_dict:Dict[str, List[str]] = {} for lane_seg in lane_seg_list: self.lane_dict[lane_seg.id] = lane_seg self.left_lane_dict[lane_seg.id] = [] self.right_lane_dict[lane_seg.id] = [] - # why AbstractLane not Lane - def add_lanes(self, lane_seg_list: List[AbstractLane]): + def add_lanes(self, lane_seg_list:List[AbstractLane]): for lane_seg in lane_seg_list: self.lane_dict[lane_seg.id] = lane_seg self.left_lane_dict[lane_seg.id] = [] @@ -32,7 +30,7 @@ class LaneMap: Warning(f'lane {lane_idx} not available') return False left_lane_list = self.left_lane_dict[lane_idx] - return len(left_lane_list) > 0 + return len(left_lane_list)>0 def left_lane(self, lane_idx): assert all((elem in self.left_lane_dict) for elem in self.lane_dict) @@ -42,7 +40,7 @@ class LaneMap: raise ValueError(f"lane_idx {lane_idx} not in lane_dict") left_lane_list = self.left_lane_dict[lane_idx] return copy.deepcopy(left_lane_list[0]) - + def has_right(self, lane_idx): if isinstance(lane_idx, Enum): lane_idx = lane_idx.name @@ -50,7 +48,7 @@ class LaneMap: Warning(f'lane {lane_idx} not available') return False right_lane_list = self.right_lane_dict[lane_idx] - return len(right_lane_list) > 0 + return len(right_lane_list)>0 def right_lane(self, lane_idx): assert all((elem in self.right_lane_dict) for elem in self.lane_dict) @@ -60,34 +58,34 @@ class LaneMap: raise ValueError(f"lane_idx {lane_idx} not in lane_dict") right_lane_list = self.right_lane_dict[lane_idx] return copy.deepcopy(right_lane_list[0]) - + def lane_geometry(self, lane_idx): if isinstance(lane_idx, Enum): lane_idx = lane_idx.name return self.lane_dict[lane_idx].get_geometry() - def get_longitudinal_position(self, lane_idx: str, position: np.ndarray) -> float: + def get_longitudinal_position(self, lane_idx:str, position:np.ndarray) -> float: if not isinstance(position, np.ndarray): position = np.array(position) lane = self.lane_dict[lane_idx] return lane.get_longitudinal_position(position) - def get_lateral_distance(self, lane_idx: str, position: np.ndarray) -> float: + def get_lateral_distance(self, lane_idx:str, position:np.ndarray) -> float: if not isinstance(position, np.ndarray): position = np.array(position) lane = self.lane_dict[lane_idx] return lane.get_lateral_distance(position) - def get_altitude(self, lane_idx, position: np.ndarray) -> float: + def get_altitude(self, lane_idx, position:np.ndarray) -> float: raise NotImplementedError - def get_lane_heading(self, lane_idx: str, position: np.ndarray) -> float: + def get_lane_heading(self, lane_idx:str, position: np.ndarray) -> float: if not isinstance(position, np.ndarray): position = np.array(position) lane = self.lane_dict[lane_idx] return lane.get_heading(position) - def get_lane_segment(self, lane_idx: str, position: np.ndarray) -> AbstractLane: + def get_lane_segment(self, lane_idx:str, position: np.ndarray) -> AbstractLane: if not isinstance(position, np.ndarray): position = np.array(position) lane = self.lane_dict[lane_idx] diff --git a/dryvr_plus_plus/map/lane_segment.py b/dryvr_plus_plus/map/lane_segment.py index db2a32c0003830d1ac34831f12f8a07ee5fd51f6..a145f2d374737c8ad860f12075b078ab83e24732 100644 --- a/dryvr_plus_plus/map/lane_segment.py +++ b/dryvr_plus_plus/map/lane_segment.py @@ -5,7 +5,6 @@ import copy from dryvr_plus_plus.reachability_engine.utils import wrap_to_pi, Vector, get_class_path, to_serializable - class LineType: """A lane side line type.""" @@ -15,7 +14,6 @@ class LineType: CONTINUOUS = 2 CONTINUOUS_LINE = 3 - class AbstractLane(object): """A lane on the road, described by its central curve.""" @@ -27,7 +25,7 @@ class AbstractLane(object): longitudinal_start: float = 0 line_types: List["LineType"] - def __init__(self, id: str): + def __init__(self, id:str): self.id = id self.type = None @@ -139,19 +137,18 @@ class AbstractLane(object): angle = np.abs(wrap_to_pi(heading - self.heading_at(s))) return abs(r) + max(s - self.length, 0) + max(0 - s, 0) + heading_weight*angle - class StraightLane(AbstractLane): """A lane going in straight line.""" def __init__(self, - id: str, + id: str, start: Vector, end: Vector, width: float = AbstractLane.DEFAULT_WIDTH, line_types: Tuple[LineType, LineType] = None, forbidden: bool = False, - speed_limit: List[Tuple[float, float]] = None, + speed_limit: float = 20, priority: int = 0) -> None: """ New straight lane. @@ -167,19 +164,14 @@ class StraightLane(AbstractLane): self.start = np.array(start) self.end = np.array(end) self.width = width - self.heading = np.arctan2( - self.end[1] - self.start[1], self.end[0] - self.start[0]) + self.heading = np.arctan2(self.end[1] - self.start[1], self.end[0] - self.start[0]) self.length = np.linalg.norm(self.end - self.start) self.line_types = line_types or [LineType.STRIPED, LineType.STRIPED] self.direction = (self.end - self.start) / self.length - self.direction_lateral = np.array( - [-self.direction[1], self.direction[0]]) + self.direction_lateral = np.array([-self.direction[1], self.direction[0]]) self.forbidden = forbidden self.priority = priority - if speed_limit != None: - self.speed_limit = sorted(speed_limit, key=lambda elem: elem[0]) - else: - self.speed_limit = None + self.speed_limit = speed_limit self.type = 'Straight' self.longitudinal_start = 0 @@ -198,54 +190,6 @@ class StraightLane(AbstractLane): lateral = np.dot(delta, self.direction_lateral) return float(longitudinal), float(lateral) - def speed_limit_at(self, longitudinal: float) -> float: - # print(self.speed_limit) - if longitudinal >= self.speed_limit[-1][0]: - # print(longitudinal, self.speed_limit[-1][1]) - return self.speed_limit[-1][1] - prev_limit = self.speed_limit[0][1] - for (start, limit) in self.speed_limit: - if longitudinal <= start: - # print(longitudinal, prev_limit) - return prev_limit - prev_limit = limit - - return -1 - # in format for polty filling mode - - def get_all_speed(self): - end_longitudinal, end_lateral = self.local_coordinates(self.end) - ret_x = [] - ret_y = [] - ret_v = [] - x_y = np.ndarray(shape=2) - seg_pos = [] - speed_limit = copy.deepcopy(self.speed_limit) - speed_limit.append(tuple([end_longitudinal, self.speed_limit[-1][1]])) - for i in range(len(self.speed_limit)): - seg_start = speed_limit[i][0] - limit = speed_limit[i][1] - if end_longitudinal < seg_start: - break - seg_pos = [] - seg_end = min(end_longitudinal, speed_limit[i+1][0]) - x_y = self.position(seg_start, self.width/2) - seg_pos.append(x_y.tolist()) - x_y = self.position(seg_end, self.width/2) - seg_pos.append(x_y.tolist()) - x_y = self.position(seg_end, -self.width/2) - seg_pos.append(x_y.tolist()) - x_y = self.position(seg_start, -self.width/2) - seg_pos.append(x_y.tolist()) - ret_x.append([pos[0] for pos in seg_pos]) - ret_y.append([pos[1] for pos in seg_pos]) - ret_v.append(limit) - # print('get_all_speed') - # print(ret_x) - # print(ret_y) - # print(ret_v) - return ret_x, ret_y, ret_v - @classmethod def from_config(cls, config: dict): config["start"] = np.array(config["start"]) @@ -272,7 +216,7 @@ class CircularLane(AbstractLane): """A lane going in circle arc.""" def __init__(self, - id, + id, center: Vector, radius: float, start_phase: float, @@ -281,7 +225,7 @@ class CircularLane(AbstractLane): width: float = AbstractLane.DEFAULT_WIDTH, line_types: List[LineType] = None, forbidden: bool = False, - speed_limit: List[Tuple[float, float]] = None, + speed_limit: float = 20, priority: int = 0) -> None: super().__init__(id) self.center = np.array(center) @@ -295,7 +239,7 @@ class CircularLane(AbstractLane): self.forbidden = forbidden self.length = radius*(end_phase - start_phase) * self.direction self.priority = priority - self.speed_limit = sorted(speed_limit, key=lambda elem: elem[0]) + self.speed_limit = speed_limit self.type = 'Circular' self.longitudinal_start = 0 @@ -320,12 +264,6 @@ class CircularLane(AbstractLane): lateral = self.direction*(self.radius - r) return longitudinal, lateral - def speed_limit_at(self, longitudinal: float) -> float: - for (start, limit) in self.speed_limit: - if longitudinal <= start: - return limit - return -1 - @classmethod def from_config(cls, config: dict): config["center"] = np.array(config["center"]) @@ -350,15 +288,15 @@ class CircularLane(AbstractLane): class LaneSegment: - def __init__(self, id, lane_parameter=None): + def __init__(self, id, lane_parameter = None): self.id = id # self.left_lane:List[str] = left_lane - # self.right_lane:List[str] = right_lane + # self.right_lane:List[str] = right_lane # self.next_segment:int = next_segment - self.lane_parameter = None + self.lane_parameter = None if lane_parameter is not None: self.lane_parameter = lane_parameter def get_geometry(self): - return self.lane_parameter + return self.lane_parameter \ No newline at end of file diff --git a/dryvr_plus_plus/plotter/__init__.py b/dryvr_plus_plus/plotter/__init__.py index 82943711d2fd0ab8bf62d81d5d4546d962ee1715..d1a378c88806ee36e77dc8d20301f60ea778c583 100644 --- a/dryvr_plus_plus/plotter/__init__.py +++ b/dryvr_plus_plus/plotter/__init__.py @@ -20,4 +20,4 @@ # # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Codes are far away from bugs with the protection -from . import parser, plotter2D, plotter3D +from . import plotter2D, plotter3D diff --git a/dryvr_plus_plus/plotter/parser.py b/dryvr_plus_plus/plotter/parser.py deleted file mode 100644 index e5567fbd046e05a2895aeef5a2e7a7078c74a2ca..0000000000000000000000000000000000000000 --- a/dryvr_plus_plus/plotter/parser.py +++ /dev/null @@ -1,38 +0,0 @@ -""" -This file consist parser code for DryVR reachtube output -""" - -from typing import TextIO -import re - -class Parser: - def __init__(self, f: TextIO): - data = f.readlines() - curr_key = "" - self.data_dict = {} - i = 0 - while i < len(data): - line = data[i] - if not re.match('^[-+0-9+.+0-9+e+0-9 ]+$', line): - self.data_dict[line] = [] - curr_key = line - i += 1 - else: - line_lower = data[i] - line_lower_list = line_lower.split(' ') - line_lower_list = [float(elem) for elem in line_lower_list] - line_upper = data[i+1] - line_upper_list = line_upper.split(' ') - line_upper_list = [float(elem) for elem in line_upper_list] - rect = [line_lower_list, line_upper_list] - self.data_dict[curr_key].append(rect) - i += 2 - - - def get_all_data(self): - res = [] - for key in self.data_dict: - res += self.data_dict[key] - return res - - \ No newline at end of file diff --git a/dryvr_plus_plus/plotter/plotter2D.py b/dryvr_plus_plus/plotter/plotter2D.py index 7f0ac77edcc26aaf27f0a018b902b043ecde8f8c..c6f22b51d2b3a69c9f7b3036fbcb25af22d4807f 100644 --- a/dryvr_plus_plus/plotter/plotter2D.py +++ b/dryvr_plus_plus/plotter/plotter2D.py @@ -5,22 +5,21 @@ This file consist main plotter code for DryVR reachtube output from __future__ import annotations import numpy as np import plotly.graph_objects as go -from typing import List, Tuple +from typing import List, Tuple, Union from plotly.graph_objs.scatter import Marker -# from dryvr_plus_plus import AnalysisTreeNode - -colors = [['#CC0000', '#FF0000', '#FF3333', '#FF6666', '#FF9999'], - ['#CC6600', '#FF8000', '#FF9933', '#FFB266', '#FFCC99'], - ['#CCCC00', '#FFFF00', '#FFFF33', '#FFFF66', '#FFFF99'], - ['#66CC00', '#80FF00', '#99FF33', '#B2FF66', '#CCFF99'], - ['#00CC00', '#00FF00', '#33FF33', '#66FF66', '#99FF99'], - ['#00CC66', '#00FF80', '#33FF99', '#66FFB2', '#99FFCC'], - ['#00CCCC', '#00FFFF', '#33FFFF', '#66FFFF', '#99FFFF'], - ['#0066CC', '#0080FF', '#3399FF', '#66B2FF', '#99CCFF'], - ['#0000CC', '#0000FF', '#3333FF', '#6666FF', '#9999FF'], - ['#6600CC', '#7F00FF', '#9933FF', '#B266FF', '#CC99FF'], - ['#CC00CC', '#FF00FF', '#FF33FF', '#FF66FF', '#FF99FF'], - ['#CC0066', '#FF007F', '#FF3399', '#FF66B2', '#FF99CC'] +from dryvr_plus_plus.analysis.analysis_tree import AnalysisTree, AnalysisTreeNode + +colors = [['#CC0000', '#FF0000', '#FF3333', '#FF6666', '#FF9999', '#FFCCCC'], + ['#CCCC00', '#FFFF00', '#FFFF33', '#FFFF66', '#FFFF99', '#FFE5CC'], + ['#66CC00', '#80FF00', '#99FF33', '#B2FF66', '#CCFF99', '#FFFFCC'], + ['#00CC00', '#00FF00', '#33FF33', '#66FF66', '#99FF99', '#E5FFCC'], + ['#00CC66', '#00FF80', '#33FF99', '#66FFB2', '#99FFCC', '#CCFFCC'], + ['#00CCCC', '#00FFFF', '#33FFFF', '#66FFFF', '#99FFFF', '#CCFFE5'], + ['#0066CC', '#0080FF', '#3399FF', '#66B2FF', '#99CCFF', '#CCE5FF'], + ['#0000CC', '#0000FF', '#3333FF', '#6666FF', '#9999FF', '#CCCCFF'], + ['#6600CC', '#7F00FF', '#9933FF', '#B266FF', '#CC99FF', '#E5CCFF'], + ['#CC00CC', '#FF00FF', '#FF33FF', '#FF66FF', '#FF99FF', '#FFCCFF'], + ['#CC0066', '#FF007F', '#FF3399', '#FF66B2', '#FF99CC', '#FFCCE5'] ] scheme_dict = {'red': 0, 'orange': 1, 'yellow': 2, 'yellowgreen': 3, 'lime': 4, 'springgreen': 5, 'cyan': 6, 'cyanblue': 7, 'blue': 8, 'purple': 9, 'magenta': 10, 'pink': 11} @@ -55,7 +54,9 @@ These 5 functions share the same API. """ -def reachtube_anime(root, map=None, fig=go.Figure(), x_dim: int = 1, y_dim: int = 2, map_type='lines', scale_type='trace', print_dim_list=None): +def reachtube_anime(root: Union[AnalysisTree, AnalysisTreeNode], map=None, fig=go.Figure(), x_dim: int = 1, y_dim: int = 2, map_type='lines', scale_type='trace', print_dim_list=None, label_mode='None'): + if isinstance(root, AnalysisTree): + root = root.root """It gives the animation of the verfication.""" agent_list = list(root.agent.keys()) timed_point_dict = {} @@ -152,25 +153,27 @@ def reachtube_anime(root, map=None, fig=go.Figure(), x_dim: int = 1, y_dim: int x_max = max(x_max, max(trace[:, x_dim])) y_min = min(y_min, min(trace[:, y_dim])) y_max = max(y_max, max(trace[:, y_dim])) - if previous_mode[agent_id] != node.mode[agent_id]: - text_pos, text = get_text_pos(node.mode[agent_id][0]) - x0 = trace[0, x_dim] - x1 = trace[1, x_dim] - y0 = trace[0, y_dim] - y1 = trace[1, y_dim] - mode_point_color = colors[agent_list.index(agent_id) % 12][0] - fig.add_trace(go.Scatter(x=[(x0+x1)/2], y=[(y0+y1)/2], - mode='markers+text', - line_color=mode_point_color, - text=str(agent_id)+': ' + text, - textposition=text_pos, - opacity=0.5, - textfont=dict( - size=text_size, - color=mode_text_color), - showlegend=False, - )) - previous_mode[agent_id] = node.mode[agent_id] + if label_mode != 'None': + if previous_mode[agent_id] != node.mode[agent_id]: + text_pos, text = get_text_pos(node.mode[agent_id][0]) + x0 = trace[0, x_dim] + x1 = trace[1, x_dim] + y0 = trace[0, y_dim] + y1 = trace[1, y_dim] + mode_point_color = colors[agent_list.index( + agent_id) % 12][0] + fig.add_trace(go.Scatter(x=[(x0+x1)/2], y=[(y0+y1)/2], + mode='markers+text', + line_color=mode_point_color, + text=str(agent_id)+': ' + text, + textposition=text_pos, + opacity=0.5, + textfont=dict( + size=text_size, + color=mode_text_color), + showlegend=False, + )) + previous_mode[agent_id] = node.mode[agent_id] queue += node.child if scale_type == 'trace': fig.update_xaxes( @@ -180,7 +183,9 @@ def reachtube_anime(root, map=None, fig=go.Figure(), x_dim: int = 1, y_dim: int return fig -def reachtube_tree(root, map=None, fig=go.Figure(), x_dim: int = 1, y_dim=2, map_type='lines', scale_type='trace', print_dim_list=None): +def reachtube_tree(root: Union[AnalysisTree, AnalysisTreeNode], map=None, fig=go.Figure(), x_dim: int = 1, y_dim=2, map_type='lines', scale_type='trace', print_dim_list=None, label_mode='None'): + if isinstance(root, AnalysisTree): + root = root.root """It statically shows all the traces of the verfication.""" fig = draw_map(map=map, fig=fig, fill_type=map_type) agent_list = list(root.agent.keys()) @@ -209,7 +214,7 @@ def reachtube_tree(root, map=None, fig=go.Figure(), x_dim: int = 1, y_dim=2, map while queue != []: node = queue.pop(0) traces = node.trace - print({k: len(v) for k, v in traces.items()}) + # print({k: len(v) for k, v in traces.items()}) i = 0 for agent_id in traces: trace = np.array(traces[agent_id]) @@ -219,21 +224,23 @@ def reachtube_tree(root, map=None, fig=go.Figure(), x_dim: int = 1, y_dim=2, map y_min = min(y_min, min(trace[:, y_dim])) y_max = max(y_max, max(trace[:, y_dim])) i = agent_list.index(agent_id) - if previous_mode[agent_id] != node.mode[agent_id]: - text_pos, text = get_text_pos(node.mode[agent_id][0]) - mode_point_color = colors[agent_list.index(agent_id) % 12][0] - fig.add_trace(go.Scatter(x=[trace[0, x_dim]], y=[trace[0, y_dim]], - mode='markers+text', - line_color=mode_point_color, - opacity=0.5, - text=str(agent_id)+': ' + text, - textposition=text_pos, - textfont=dict( - size=text_size, - color=mode_text_color), - showlegend=False, - )) - previous_mode[agent_id] = node.mode[agent_id] + if label_mode != 'None': + if previous_mode[agent_id] != node.mode[agent_id]: + text_pos, text = get_text_pos(node.mode[agent_id][0]) + mode_point_color = colors[agent_list.index( + agent_id) % 12][0] + fig.add_trace(go.Scatter(x=[trace[0, x_dim]], y=[trace[0, y_dim]], + mode='markers+text', + line_color=mode_point_color, + opacity=0.5, + text=str(agent_id)+': ' + text, + textposition=text_pos, + textfont=dict( + size=text_size, + color=mode_text_color), + showlegend=False, + )) + previous_mode[agent_id] = node.mode[agent_id] queue += node.child if scale_type == 'trace': fig.update_xaxes( @@ -243,7 +250,8 @@ def reachtube_tree(root, map=None, fig=go.Figure(), x_dim: int = 1, y_dim=2, map return fig -def simulation_tree(root, map=None, fig=None, x_dim: int = 1, y_dim=2, map_type='lines', scale_type='trace', print_dim_list=None): +def simulation_tree(root: Union[AnalysisTree, AnalysisTreeNode], map=None, fig=None, x_dim: int = 1, y_dim=2, map_type='lines', scale_type='trace', print_dim_list=None, label_mode='None'): + root = root.root """It statically shows all the traces of the simulation.""" fig = draw_map(map=map, fig=fig, fill_type=map_type) agent_list = list(root.agent.keys()) @@ -316,7 +324,9 @@ def simulation_tree(root, map=None, fig=None, x_dim: int = 1, y_dim=2, map_type= return fig -def simulation_anime(root, map=None, fig=None, x_dim: int = 1, y_dim=2, map_type='lines', scale_type='trace', print_dim_list=None): +def simulation_anime(root: Union[AnalysisTree, AnalysisTreeNode], map=None, fig=None, x_dim: int = 1, y_dim=2, map_type='lines', scale_type='trace', print_dim_list=None, label_mode='None'): + if isinstance(root, AnalysisTree): + root = root.root """It gives the animation of the simulation without trail but is faster.""" timed_point_dict = {} stack = [root] @@ -439,20 +449,21 @@ def simulation_anime(root, map=None, fig=None, x_dim: int = 1, y_dim=2, map_type for agent_id in traces: trace = np.array(traces[agent_id]) mode_point_color = colors[agent_list.index(agent_id) % 12][0] - if previous_mode[agent_id] != node.mode[agent_id]: - text_pos, text = get_text_pos(node.mode[agent_id][0]) - fig.add_trace(go.Scatter(x=[trace[0, x_dim]], y=[trace[0, y_dim]], - mode='markers+text', - line_color=mode_point_color, - text=str(agent_id)+': ' + text, - textposition=text_pos, - opacity=0.5, - textfont=dict( - size=text_size, - color=mode_text_color), - showlegend=False, - )) - previous_mode[agent_id] = node.mode[agent_id] + if label_mode != 'None': + if previous_mode[agent_id] != node.mode[agent_id]: + text_pos, text = get_text_pos(node.mode[agent_id][0]) + fig.add_trace(go.Scatter(x=[trace[0, x_dim]], y=[trace[0, y_dim]], + mode='markers+text', + line_color=mode_point_color, + text=str(agent_id)+': ' + text, + textposition=text_pos, + opacity=0.5, + textfont=dict( + size=text_size, + color=mode_text_color), + showlegend=False, + )) + previous_mode[agent_id] = node.mode[agent_id] queue += node.child if scale_type == 'trace': fig.update_xaxes( @@ -463,8 +474,10 @@ def simulation_anime(root, map=None, fig=None, x_dim: int = 1, y_dim=2, map_type return fig -def simulation_anime_trail(root, map=None, fig=go.Figure(), x_dim: int = 1, y_dim=2, map_type='lines', scale_type='trace', print_dim_list=None): +def simulation_anime_trail(root: Union[AnalysisTree, AnalysisTreeNode], map=None, fig=go.Figure(), x_dim: int = 1, y_dim=2, map_type='lines', scale_type='trace', print_dim_list=None, label_mode='None'): """It gives the animation of the simulation with trail.""" + if isinstance(root, AnalysisTree): + root = root.root timed_point_dict = {} stack = [root] x_min, x_max = float('inf'), -float('inf') @@ -615,21 +628,22 @@ def simulation_anime_trail(root, map=None, fig=go.Figure(), x_dim: int = 1, y_di trace_x = trace[:, x_dim].tolist() i = agent_list.index(agent_id) mode_point_color = colors[agent_list.index(agent_id) % 12][0] - if previous_mode[agent_id] != node.mode[agent_id]: - text_pos, text = get_text_pos(node.mode[agent_id][0]) - fig.add_trace(go.Scatter(x=[trace[0, x_dim]], y=[trace[0, y_dim]], - mode='markers+text', - line_color=mode_point_color, - text=str(agent_id)+': ' + text, - opacity=0.5, - textposition=text_pos, - textfont=dict( - size=text_size, - color=mode_text_color - ), - showlegend=False, - )) - previous_mode[agent_id] = node.mode[agent_id] + if label_mode != 'None': + if previous_mode[agent_id] != node.mode[agent_id]: + text_pos, text = get_text_pos(node.mode[agent_id][0]) + fig.add_trace(go.Scatter(x=[trace[0, x_dim]], y=[trace[0, y_dim]], + mode='markers+text', + line_color=mode_point_color, + text=str(agent_id)+': ' + text, + opacity=0.5, + textposition=text_pos, + textfont=dict( + size=text_size, + color=mode_text_color + ), + showlegend=False, + )) + previous_mode[agent_id] = node.mode[agent_id] queue += node.child if scale_type == 'trace': fig.update_xaxes( @@ -643,7 +657,9 @@ def simulation_anime_trail(root, map=None, fig=go.Figure(), x_dim: int = 1, y_di """Functions below are low-level functions and usually are not called outside this file.""" -def reachtube_tree_single(root, agent_id, fig=go.Figure(), x_dim: int = 1, y_dim: int = 2, color=None, print_dim_list=None): +def reachtube_tree_single(root: Union[AnalysisTree, AnalysisTreeNode], agent_id, fig=go.Figure(), x_dim: int = 1, y_dim: int = 2, color=None, print_dim_list=None): + if isinstance(root, AnalysisTree): + root = root.root """It statically shows the verfication traces of one given agent.""" global color_cnt if color == None: @@ -651,7 +667,7 @@ def reachtube_tree_single(root, agent_id, fig=go.Figure(), x_dim: int = 1, y_dim color_cnt = (color_cnt+1) % 12 queue = [root] show_legend = False - fillcolor = colors[scheme_dict[color]][4] + fillcolor = colors[scheme_dict[color]][5] while queue != []: node = queue.pop(0) traces = node.trace @@ -663,57 +679,79 @@ def reachtube_tree_single(root, agent_id, fig=go.Figure(), x_dim: int = 1, y_dim trace_y_odd = np.array([trace[i][y_dim] for i in range(0, max_id, 2)]) trace_y_even = np.array([trace[i][y_dim] for i in range(1, max_id+1, 2)]) - fig.add_trace(go.Scatter(x=trace_x_odd.tolist()+trace_x_odd[::-1].tolist(), y=trace_y_odd.tolist()+trace_y_even[::-1].tolist(), mode='lines', + # trace_y_new = [0]*len(trace) + # trace_y_new[0] = trace[0][y_dim] + # trace_y_new[int(len(trace)/2)] = trace[-1][y_dim] + # for i in range(1, max_id, 2): + # if trace[i][y_dim] > trace[i+1][y_dim]: + # trace_y_new[i] = trace[i][y_dim] + # trace_y_new[max_id+1-i] = trace[i+1][y_dim] + # else: + # trace_y_new[i] = trace[i+1][y_dim] + # trace_y_new[max_id+1-i] = trace[i][y_dim] + # trace_y_new=np.array(trace_y_new) + # fig.add_trace(go.Scatter(x=trace_x_odd.tolist()+trace_x_even[::-1].tolist(), y=trace_y_new, mode='lines', + # fill='toself', + # fillcolor=fillcolor, + # opacity=0.5, + # line_color='rgba(255,255,255,0)', + # showlegend=show_legend + # )) + # fig.add_trace(go.Scatter(x=trace_x_odd.tolist()+trace_x_odd[::-1].tolist(), y=trace_y_odd.tolist()+trace_y_even[::-1].tolist(), mode='lines', + # fill='toself', + # fillcolor=fillcolor, + # opacity=0.5, + # line_color='rgba(255,255,255,0)', + # showlegend=show_legend + # )) + # fig.add_trace(go.Scatter(x=trace_x_even.tolist()+trace_x_even[::-1].tolist(), y=trace_y_odd.tolist()+trace_y_even[::-1].tolist(), mode='lines', + # fill='toself', + # fillcolor=fillcolor, + # opacity=0.5, + # line_color='rgba(255,255,255,0)', + # showlegend=show_legend)) + fig.add_trace(go.Scatter(x=trace_x_odd.tolist()+trace_x_even[::-1].tolist()+[trace_x_odd[0]], y=trace_y_odd.tolist()+trace_y_even[::-1].tolist()+[trace_y_odd[0]], mode='markers+lines', fill='toself', fillcolor=fillcolor, - opacity=0.5, - line_color='rgba(255,255,255,0)', + # opacity=0.5, + marker={'size': 1}, + line_color=colors[scheme_dict[color]][2], + line={'width': 1}, showlegend=show_legend )) - fig.add_trace(go.Scatter(x=trace_x_even.tolist()+trace_x_even[::-1].tolist(), y=trace_y_odd.tolist()+trace_y_even[::-1].tolist(), mode='lines', - fill='toself', - fillcolor=fillcolor, - opacity=0.5, - line_color='rgba(255,255,255,0)', - showlegend=show_legend)) - fig.add_trace(go.Scatter(x=trace_x_odd.tolist()+trace_x_even[::-1].tolist(), y=trace_y_odd.tolist()+trace_y_even[::-1].tolist(), mode='lines', - fill='toself', - fillcolor=fillcolor, - opacity=0.5, - line_color='rgba(255,255,255,0)', - showlegend=show_legend - )) - fig.add_trace(go.Scatter(x=trace_x_even.tolist()+trace_x_odd[::-1].tolist(), y=trace_y_odd.tolist()+trace_y_even[::-1].tolist(), mode='lines', - fill='toself', - fillcolor=fillcolor, - opacity=0.5, - line_color='rgba(255,255,255,0)', - showlegend=show_legend)) + # fig.add_trace(go.Scatter(x=trace_x_even.tolist()+trace_x_odd[::-1].tolist(), y=trace_y_odd.tolist()+trace_y_even[::-1].tolist(), mode='lines', + # fill='toself', + # fillcolor=fillcolor, + # opacity=0.5, + # line_color='rgba(255,255,255,0)', + # showlegend=show_legend)) queue += node.child - queue = [root] - while queue != []: - node = queue.pop(0) - traces = node.trace - trace = np.array(traces[agent_id]) - max_id = len(trace)-1 - fig.add_trace(go.Scatter(x=trace[:, x_dim], y=trace[:, y_dim], - mode='markers', - text=[ - ['{:.2f}'.format(trace[i, j])for j in print_dim_list] for i in range(0, trace.shape[0])], - line_color=colors[scheme_dict[color]][0], - marker={ - "sizemode": "area", - "sizeref": 200000, - "size": 2 - }, - name='lines', - showlegend=False)) - queue += node.child + # queue = [root] + # while queue != []: + # node = queue.pop(0) + # traces = node.trace + # trace = np.array(traces[agent_id]) + # max_id = len(trace)-1 + # fig.add_trace(go.Scatter(x=trace[:, x_dim], y=trace[:, y_dim], + # mode='markers', + # text=[ + # ['{:.2f}'.format(trace[i, j])for j in print_dim_list] for i in range(0, trace.shape[0])], + # line_color=colors[scheme_dict[color]][0], + # marker={ + # "sizemode": "area", + # "sizeref": 200000, + # "size": 2 + # }, + # name='lines', + # showlegend=False)) + # queue += node.child return fig -def simulation_tree_single(root, agent_id, fig: go.Figure = go.Figure(), x_dim: int = 1, y_dim: int = 2, color=None, print_dim_list=None): +def simulation_tree_single(root: Union[AnalysisTree, AnalysisTreeNode], agent_id, fig: go.Figure = go.Figure(), x_dim: int = 1, y_dim: int = 2, color=None, print_dim_list=None): + if isinstance(root, AnalysisTree): + root = root.root """It statically shows the simulation traces of one given agent.""" global color_cnt queue = [root] @@ -741,17 +779,32 @@ def simulation_tree_single(root, agent_id, fig: go.Figure = go.Figure(), x_dim: count_dict[time] = 1 start_list.append(start) end_list.append(end) + fig.add_trace(go.Scatter(x=trace[:, x_dim], y=trace[:, y_dim], mode='lines', line_color=colors[scheme_dict[color] ][color_id], text=[ - ['{:.2f}'.format(trace[i, j])for j in print_dim_list] for i in range(trace.shape[0])], - legendgroup=agent_id, - legendgrouptitle_text=agent_id, - name=str(round(start[0], 2))+'-'+str(round(end[0], 2)) + - '-'+str(count_dict[time]), - showlegend=True)) + ['{:.2f}'.format(trace[i, j])for j in print_dim_list] for i in range(trace.shape[0])], + legendgroup=agent_id, + legendgrouptitle_text=agent_id, + name=str(round(start[0], 2))+'-'+str(round(end[0], 2)) + + '-'+str(count_dict[time]), + showlegend=True)) + if node.assert_hits != None and agent_id in node.assert_hits: + fig.add_trace(go.Scatter(x=[trace[-1, x_dim]], y=[trace[-1, y_dim]], + mode='markers+text', + # line_color='grey', + text=['HIT:\n' + + a for a in node.assert_hits[agent_id]], + textfont={'color': 'grey'}, + legendgroup=agent_id, + marker={'size': 4, 'color': 'black'}, + legendgrouptitle_text=agent_id, + name=str(round(start[0], 2))+'-'+str(round(end[0], 2)) + + '-'+str(count_dict[time])+'hit', + showlegend=True)) + color_id = (color_id+4) % 5 queue += node.child fig.update_layout(legend=dict( @@ -981,6 +1034,6 @@ def get_text_pos(veh_mode): text_pos = 'middle center' text = 'S' else: - print(veh_mode) - raise ValueError + text_pos = 'middle center' + text = veh_mode return text_pos, text diff --git a/dryvr_plus_plus/reachability_engine/dryvr.py b/dryvr_plus_plus/reachability_engine/dryvr.py index e480fe94f76605acb3a91c22261dda399f1b8a98..d7e27985412cf744bb8ccc5e6609521b3288224b 100644 --- a/dryvr_plus_plus/reachability_engine/dryvr.py +++ b/dryvr_plus_plus/reachability_engine/dryvr.py @@ -41,15 +41,13 @@ def get_reachtube_segment(training_traces: np.ndarray, initial_radii: np.ndarray trace_initial_time = center_trace[0, 0] x_points: np.ndarray = center_trace[:, 0] - trace_initial_time assert np.all(training_traces[0, :, 0] == training_traces[1:, :, 0]) - y_points: np.ndarray = all_sensitivities_calc( - training_traces, initial_radii) + y_points: np.ndarray = all_sensitivities_calc(training_traces, initial_radii) points: np.ndarray = np.zeros((ndims - 1, trace_len, 2)) points[np.where(initial_radii != 0), 0, 1] = 1.0 points[:, :, 0] = np.reshape(x_points, (1, x_points.shape[0])) points[:, 1:, 1] = y_points normalizing_initial_set_radii: np.ndarray = initial_radii.copy() - normalizing_initial_set_radii[np.where( - normalizing_initial_set_radii == 0)] = 1.0 + normalizing_initial_set_radii[np.where(normalizing_initial_set_radii == 0)] = 1.0 df: np.ndarray = np.zeros((trace_len, ndims)) if method == 'PW': df[:, 1:] = np.transpose( @@ -61,8 +59,7 @@ def get_reachtube_segment(training_traces: np.ndarray, initial_radii: np.ndarray points[:, :, 1] = np.maximum(points[:, :, 1], _EPSILON) points[:, :, 1] = np.log(points[:, :, 1]) for dim_ind in range(1, ndims): - new_min = min( - np.min(points[dim_ind - 1, 1:, 1]) + _TRUE_MIN_CONST, -10) + new_min = min(np.min(points[dim_ind - 1, 1:, 1]) + _TRUE_MIN_CONST, -10) if initial_radii[dim_ind - 1] == 0: # exclude initial set, then add true minimum points new_points: np.ndarray = np.row_stack( @@ -70,8 +67,7 @@ def get_reachtube_segment(training_traces: np.ndarray, initial_radii: np.ndarray else: # start from zero, then add true minimum points new_points: np.ndarray = np.row_stack((points[dim_ind - 1, 0, :], - np.array( - (points[dim_ind - 1, 0, 0], new_min)), + np.array((points[dim_ind - 1, 0, 0], new_min)), np.array((points[dim_ind - 1, -1, 0], new_min)))) df[0, dim_ind] = initial_radii[dim_ind - 1] # Tuple order is start_time, end_time, slope, y-intercept @@ -87,18 +83,15 @@ def get_reachtube_segment(training_traces: np.ndarray, initial_radii: np.ndarray for end_ind, start_ind in vert_inds: if cur_dim_points[start_ind, 1] != new_min and cur_dim_points[end_ind, 1] != new_min: slope = (cur_dim_points[end_ind, 1] - cur_dim_points[start_ind, 1]) / ( - cur_dim_points[end_ind, 0] - cur_dim_points[start_ind, 0]) - y_intercept = cur_dim_points[start_ind, - 1] - cur_dim_points[start_ind, 0] * slope + cur_dim_points[end_ind, 0] - cur_dim_points[start_ind, 0]) + y_intercept = cur_dim_points[start_ind, 1] - cur_dim_points[start_ind, 0] * slope start_time = cur_dim_points[start_ind, 0] end_time = cur_dim_points[end_ind, 0] assert start_time < end_time if start_time == 0: - linear_separators.append( - (start_time, end_time, slope, y_intercept, 0, end_ind + 1)) + linear_separators.append((start_time, end_time, slope, y_intercept, 0, end_ind + 1)) else: - linear_separators.append( - (start_time, end_time, slope, y_intercept, start_ind + 1, end_ind + 1)) + linear_separators.append((start_time, end_time, slope, y_intercept, start_ind + 1, end_ind + 1)) linear_separators.sort() prev_val = 0 prev_ind = 1 if initial_radii[dim_ind - 1] == 0 else 0 @@ -118,17 +111,13 @@ def get_reachtube_segment(training_traces: np.ndarray, initial_radii: np.ndarray raise ValueError assert (np.all(df >= 0)) reachtube_segment: np.ndarray = np.zeros((trace_len - 1, 2, ndims)) - reachtube_segment[:, 0, :] = np.minimum( - center_trace[1:, :] - df[1:, :], center_trace[:-1, :] - df[:-1, :]) - reachtube_segment[:, 1, :] = np.maximum( - center_trace[1:, :] + df[1:, :], center_trace[:-1, :] + df[:-1, :]) + reachtube_segment[:, 0, :] = np.minimum(center_trace[1:, :] - df[1:, :], center_trace[:-1, :] - df[:-1, :]) + reachtube_segment[:, 1, :] = np.maximum(center_trace[1:, :] + df[1:, :], center_trace[:-1, :] + df[:-1, :]) # assert 100% training accuracy (all trajectories are contained) for trace_ind in range(training_traces.shape[0]): if not (np.all(reachtube_segment[:, 0, :] <= training_traces[trace_ind, 1:, :]) and np.all(reachtube_segment[:, 1, :] >= training_traces[trace_ind, 1:, :])): - assert np.any( - np.abs(training_traces[trace_ind, 0, 1:]-center_trace[0, 1:]) > initial_radii) - print(f"Warning: Trace #{trace_ind}", - "of this initial set is sampled outside of the initial set because of floating point error and is not contained in the initial set") + assert np.any(np.abs(training_traces[trace_ind, 0, 1:]-center_trace[0, 1:]) > initial_radii) + print(f"Warning: Trace #{trace_ind}", "of this initial set is sampled outside of the initial set because of floating point error and is not contained in the initial set") return reachtube_segment def calcCenterPoint(lower, upper): diff --git a/dryvr_plus_plus/scenario/scenario.py b/dryvr_plus_plus/scenario/scenario.py index 640899ea132864502d374632636765a96f3aeaa5..0fd9905fb1a49c1aed61dc09e9849b9f267b1bff 100644 --- a/dryvr_plus_plus/scenario/scenario.py +++ b/dryvr_plus_plus/scenario/scenario.py @@ -9,13 +9,14 @@ import numpy as np from dryvr_plus_plus.agents.base_agent import BaseAgent from dryvr_plus_plus.automaton import GuardExpressionAst, ResetExpression -from dryvr_plus_plus.analysis import Simulator, Verifier, AnalysisTreeNode +from dryvr_plus_plus.analysis import Simulator, Verifier, AnalysisTreeNode, AnalysisTree from dryvr_plus_plus.reachability_engine.utils import find, sample_rect from dryvr_plus_plus.sensor.base_sensor import BaseSensor from dryvr_plus_plus.map.lane_map import LaneMap EGO, OTHERS = "ego", "others" + class Scenario: def __init__(self): self.agent_dict = {} @@ -53,11 +54,14 @@ class Scenario: # agent.controller.vertices = list(itertools.product(*mode_vals)) # agent.controller.vertexStrings = [','.join(elem) for elem in agent.controller.vertices] - def set_init(self, init_list, init_mode_list, static_list = []): + def set_init(self, init_list, init_mode_list, static_list=[]): assert len(init_list) == len(self.agent_dict) assert len(init_mode_list) == len(self.agent_dict) - assert len(static_list) == len(self.agent_dict) or len(static_list) == 0 - for i,agent_id in enumerate(self.agent_dict.keys()): + assert len(static_list) == len( + self.agent_dict) or len(static_list) == 0 + print(init_mode_list) + print(type(init_mode_list)) + for i, agent_id in enumerate(self.agent_dict.keys()): self.init_dict[agent_id] = copy.deepcopy(init_list[i]) self.init_mode_dict[agent_id] = copy.deepcopy(init_mode_list[i]) if static_list: @@ -72,7 +76,7 @@ class Scenario: res_list.append(trace) return res_list - def simulate(self, time_horizon, time_step): + def simulate(self, time_horizon, time_step) -> AnalysisTree: init_list = [] init_mode_list = [] static_list = [] @@ -85,7 +89,7 @@ class Scenario: print(init_list) return self.simulator.simulate(init_list, init_mode_list, static_list, agent_list, self, time_horizon, time_step, self.map) - def verify(self, time_horizon, time_step): + def verify(self, time_horizon, time_step) -> AnalysisTree: init_list = [] init_mode_list = [] static_list = [] @@ -116,7 +120,7 @@ class Scenario: # The reset_list here are all the resets for a single transition. Need to evaluate each of them # and then combine them together for reset_tuple in reset_list: - reset, disc_var_dict, cont_var_dict = reset_tuple + reset, disc_var_dict, cont_var_dict, _ = reset_tuple reset_variable = reset.var expr = reset.expr # First get the transition destinations @@ -127,7 +131,8 @@ class Scenario: found = True break if not found: - raise ValueError(f'Reset discrete variable {discrete_variable_ego} not found') + raise ValueError( + f'Reset discrete variable {discrete_variable_ego} not found') if isinstance(reset.val_ast, ast.Constant): val = eval(expr) possible_dest[var_loc] = [val] @@ -175,7 +180,7 @@ class Scenario: for comb in comb_list: val_dict = {} tmp = copy.deepcopy(expr) - for symbol_idx,symbol in enumerate(symbols): + for symbol_idx, symbol in enumerate(symbols): tmp = tmp.replace(symbol, str(comb[symbol_idx])) res = eval(tmp, {}, val_dict) lb = min(lb, res) @@ -198,7 +203,7 @@ class Scenario: comb_list = list(itertools.product(*data_list)) return comb_list - def apply_cont_var_updater(self,cont_var_dict, updater): + def apply_cont_var_updater(self, cont_var_dict, updater): for variable in updater: for unrolled_variable, unrolled_variable_index in updater[variable]: cont_var_dict[unrolled_variable] = cont_var_dict[variable][unrolled_variable_index] @@ -217,7 +222,7 @@ class Scenario: for agent_id in node.agent: # Get guard - agent:BaseAgent = self.agent_dict[agent_id] + agent: BaseAgent = self.agent_dict[agent_id] agent_mode = node.mode[agent_id] if len(agent.controller.args) == 0: continue @@ -235,20 +240,23 @@ class Scenario: satisfied_guard = [] asserts = defaultdict(list) for agent_id in agent_guard_dict: - agent:BaseAgent = self.agent_dict[agent_id] + agent: BaseAgent = self.agent_dict[agent_id] state_dict = {} for tmp in node.agent: - state_dict[tmp] = (node.trace[tmp][idx], node.mode[tmp], node.static[tmp]) + state_dict[tmp] = (node.trace[tmp][idx], + node.mode[tmp], node.static[tmp]) agent_state, agent_mode, agent_static = state_dict[agent_id] agent_state = agent_state[1:] - continuous_variable_dict, orig_disc_vars, _ = self.sensor.sense(self, agent, state_dict, self.map) + continuous_variable_dict, orig_disc_vars, _ = self.sensor.sense( + self, agent, state_dict, self.map) # Unsafety checking ego_ty_name = find(agent.controller.args, lambda a: a.name == EGO).typ def pack_env(agent: BaseAgent, cont, disc, map): env = copy.deepcopy(cont) env.update(disc) - state_ty = namedtuple(ego_ty_name, agent.controller.state_defs[ego_ty_name].all_vars()) + state_ty = namedtuple( + ego_ty_name, agent.controller.state_defs[ego_ty_name].all_vars()) packed: DefaultDict[str, Any] = defaultdict(dict) for k, v in env.items(): k = k.split(".") @@ -300,7 +308,7 @@ class Scenario: for reset_var in all_resets: iter_list.append(range(len(all_resets[reset_var]))) pos_list = list(itertools.product(*iter_list)) - if len(pos_list)==1 and pos_list[0]==(): + if len(pos_list) == 1 and pos_list[0] == (): continue for i in range(len(pos_list)): pos = pos_list[i] @@ -321,15 +329,18 @@ class Scenario: next_init[var_loc] = res all_dest = list(itertools.product(*possible_dest)) if not all_dest: - warnings.warn(f"Guard hit for mode {agent_mode} for agent {agent_id} without available next mode") + warnings.warn( + f"Guard hit for mode {agent_mode} for agent {agent_id} without available next mode") all_dest.append(None) for dest in all_dest: - satisfied_guard.append((agent_id, agent_mode, dest, next_init)) + satisfied_guard.append( + (agent_id, agent_mode, dest, next_init)) if len(asserts) > 0: return asserts, transitions, idx if len(satisfied_guard) > 0: for agent_idx, src_mode, dest_mode, next_init in satisfied_guard: - transitions[agent_idx].append((agent_idx, src_mode, dest_mode, next_init, idx)) + transitions[agent_idx].append( + (agent_idx, src_mode, dest_mode, next_init, idx)) break return None, transitions, idx @@ -349,10 +360,10 @@ class Scenario: cont_var_dict_template, discrete_variable_dict, length_dict = self.sensor.sense(self, agent, state_dict, self.map) # TODO-PARSER: Get equivalent for this function paths = agent.controller.paths - for path in paths: + for guard_idx, path in enumerate(paths): # Construct the guard expression reset = (path.var, path.val_veri) - guard_expression = GuardExpressionAst([path.cond_veri]) + guard_expression = GuardExpressionAst([path.cond_veri], guard_idx) cont_var_updater = guard_expression.parse_any_all_new(cont_var_dict_template, discrete_variable_dict, length_dict) self.apply_cont_var_updater(cont_var_dict_template, cont_var_updater) @@ -409,23 +420,33 @@ class Scenario: new_cont_var_dict = copy.deepcopy(cont_vars) one_step_guard:GuardExpressionAst = copy.deepcopy(guard_expression) - self.apply_cont_var_updater(new_cont_var_dict, continuous_variable_updater) - guard_can_satisfied = one_step_guard.evaluate_guard_hybrid(agent, discrete_variable_dict, new_cont_var_dict, self.map) + self.apply_cont_var_updater( + new_cont_var_dict, continuous_variable_updater) + guard_can_satisfied = one_step_guard.evaluate_guard_hybrid( + agent, discrete_variable_dict, new_cont_var_dict, self.map) if not guard_can_satisfied: continue - guard_satisfied, is_contained = one_step_guard.evaluate_guard_cont(agent, new_cont_var_dict, self.map) + guard_satisfied, is_contained = one_step_guard.evaluate_guard_cont( + agent, new_cont_var_dict, self.map) any_contained = any_contained or is_contained # TODO: Can we also store the cont and disc var dict so we don't have to call sensor again? if guard_satisfied: reset_expr = ResetExpression(reset) - resets[reset_expr.var].append((reset_expr, discrete_variable_dict, new_cont_var_dict)) + resets[reset_expr.var].append( + (reset_expr, discrete_variable_dict, new_cont_var_dict, guard_expression.guard_idx) + ) # Perform combination over all possible resets to generate all possible real resets combined_reset_list = list(itertools.product(*resets.values())) - if len(combined_reset_list)==1 and combined_reset_list[0]==(): + if len(combined_reset_list) == 1 and combined_reset_list[0] == (): continue for i in range(len(combined_reset_list)): + # Compute reset_idx + reset_idx = [] + for reset_info in combined_reset_list[i]: + reset_idx.append(reset_info[3]) # a list of reset expression - hits.append((agent_id, combined_reset_list[i])) + hits.append((agent_id, tuple(reset_idx), + combined_reset_list[i])) if len(asserts) > 0: return (asserts, idx), None if hits != []: @@ -439,7 +460,7 @@ class Scenario: reset_dict = {}#defaultdict(lambda: defaultdict(list)) reset_idx_dict = {}#defaultdict(lambda: defaultdict(list)) for hits, all_agent_state, hit_idx in guard_hits: - for agent_id, reset_list in hits: + for agent_id, reset_idx, reset_list in hits: # TODO: Need to change this function to handle the new reset expression and then I am done dest_list,reset_rect = self.apply_reset(node.agent[agent_id], reset_list, all_agent_state) if agent_id not in reset_dict: @@ -449,30 +470,36 @@ class Scenario: warnings.warn( f"Guard hit for mode {node.mode[agent_id]} for agent {agent_id} without available next mode") dest_list.append(None) + if reset_idx not in reset_dict[agent_id]: + reset_dict[agent_id][reset_idx] = {} + reset_idx_dict[agent_id][reset_idx] = {} for dest in dest_list: - if dest not in reset_dict[agent_id]: - reset_dict[agent_id][dest] = [] - reset_idx_dict[agent_id][dest] = [] - reset_dict[agent_id][dest].append(reset_rect) - reset_idx_dict[agent_id][dest].append(hit_idx) + if dest not in reset_dict[agent_id][reset_idx]: + reset_dict[agent_id][reset_idx][dest] = [] + reset_idx_dict[agent_id][reset_idx][dest] = [] + reset_dict[agent_id][reset_idx][dest].append(reset_rect) + reset_idx_dict[agent_id][reset_idx][dest].append(hit_idx) possible_transitions = [] # Combine reset rects and construct transitions for agent in reset_dict: - for dest in reset_dict[agent]: - combined_rect = None - for rect in reset_dict[agent][dest]: - rect = np.array(rect) - if combined_rect is None: - combined_rect = rect - else: - combined_rect[0, :] = np.minimum( - combined_rect[0, :], rect[0, :]) - combined_rect[1, :] = np.maximum( - combined_rect[1, :], rect[1, :]) - combined_rect = combined_rect.tolist() - min_idx = min(reset_idx_dict[agent][dest]) - max_idx = max(reset_idx_dict[agent][dest]) - possible_transitions.append((agent, node.mode[agent], dest, combined_rect, (min_idx, max_idx))) + for reset_idx in reset_dict[agent]: + for dest in reset_dict[agent][reset_idx]: + combined_rect = None + for rect in reset_dict[agent][reset_idx][dest]: + rect = np.array(rect) + if combined_rect is None: + combined_rect = rect + else: + combined_rect[0, :] = np.minimum( + combined_rect[0, :], rect[0, :]) + combined_rect[1, :] = np.maximum( + combined_rect[1, :], rect[1, :]) + combined_rect = combined_rect.tolist() + min_idx = min(reset_idx_dict[agent][reset_idx][dest]) + max_idx = max(reset_idx_dict[agent][reset_idx][dest]) + transition = ( + agent, node.mode[agent], dest, combined_rect, (min_idx, max_idx)) + possible_transitions.append(transition) # Return result return None, possible_transitions diff --git a/dryvr_plus_plus/sensor/base_sensor.py b/dryvr_plus_plus/sensor/base_sensor.py index b668a38d02ef08686ab9fb78422a49db2f6dbceb..3be85df796899833883e3d796addf04415fd78fb 100644 --- a/dryvr_plus_plus/sensor/base_sensor.py +++ b/dryvr_plus_plus/sensor/base_sensor.py @@ -1,51 +1,59 @@ import numpy as np from dryvr_plus_plus.agents.base_agent import BaseAgent + def sets(d, thing, attrs, vals): d.update({thing + "." + k: v for k, v in zip(attrs, vals)}) + def adds(d, thing, attrs, vals): for k, v in zip(attrs, vals): if thing + '.' + k not in d: - d[thing + '.' + k] = [v] + d[thing + '.' + k] = [v] else: - d[thing + '.' + k].append(v) + d[thing + '.' + k].append(v) + def set_states_2d(cnts, disc, thing, val, cont_var, disc_var, stat_var): state, mode, static = val - sets(cnts, thing, cont_var, state[1:5]) + sets(cnts, thing, cont_var, state[1:]) sets(disc, thing, disc_var, mode) sets(disc, thing, stat_var, static) + def set_states_3d(cnts, disc, thing, val, cont_var, disc_var, stat_var): state, mode, static = val - transp = np.transpose(np.array(state)[:, 1:5]) - assert len(transp) == 4 + transp = np.transpose(np.array(state)[:, 1:]) + # assert len(transp) == 4 sets(cnts, thing, cont_var, transp) sets(disc, thing, disc_var, mode) sets(disc, thing, stat_var, static) + def add_states_2d(cont, disc, thing, val, cont_var, disc_var, stat_var): state, mode, static = val - adds(cont, thing, cont_var, state[1:5]) + adds(cont, thing, cont_var, state[1:]) adds(disc, thing, disc_var, mode) adds(disc, thing, stat_var, static) + def add_states_3d(cont, disc, thing, val, cont_var, disc_var, stat_var): state, mode, static = val - transp = np.transpose(np.array(state)[:, 1:5]) + transp = np.transpose(np.array(state)[:, 1:]) assert len(transp) == 4 adds(cont, thing, cont_var, transp) adds(disc, thing, disc_var, mode) adds(disc, thing, stat_var, static) # TODO-PARSER: Update base sensor + + class BaseSensor(): # The baseline sensor is omniscient. Each agent can get the state of all other agents - def sense(self, scenario, agent:BaseAgent, state_dict, lane_map): + def sense(self, scenario, agent: BaseAgent, state_dict, lane_map): cont = {} disc = {} - len_dict = {'others':len(state_dict)-1} + len_dict = {'others': len(state_dict)-1} tmp = np.array(list(state_dict.values())[0][0]) if tmp.ndim < 2: for agent_id in state_dict: @@ -58,11 +66,13 @@ class BaseSensor(): arg_type = arg.typ break if arg_type is None: + continue raise ValueError(f"Invalid arg for ego") - cont_var = agent.controller.state_defs[arg_type].cont + cont_var = agent.controller.state_defs[arg_type].cont disc_var = agent.controller.state_defs[arg_type].disc stat_var = agent.controller.state_defs[arg_type].static - set_states_2d(cont, disc, 'ego', state_dict[agent_id], cont_var, disc_var, stat_var) + set_states_2d( + cont, disc, 'ego', state_dict[agent_id], cont_var, disc_var, stat_var) else: controller_args = agent.controller.args arg_type = None @@ -73,11 +83,13 @@ class BaseSensor(): arg_type = arg.typ break if arg_type is None: + continue raise ValueError(f"Invalid arg for others") - cont_var = agent.controller.state_defs[arg_type].cont + cont_var = agent.controller.state_defs[arg_type].cont disc_var = agent.controller.state_defs[arg_type].disc stat_var = agent.controller.state_defs[arg_type].static - add_states_2d(cont, disc, arg_name, state_dict[agent_id], cont_var, disc_var, stat_var) + add_states_2d( + cont, disc, arg_name, state_dict[agent_id], cont_var, disc_var, stat_var) else: for agent_id in state_dict: @@ -91,10 +103,11 @@ class BaseSensor(): break if arg_type is None: raise ValueError(f"Invalid arg for ego") - cont_var = agent.controller.state_defs[arg_type].cont + cont_var = agent.controller.state_defs[arg_type].cont disc_var = agent.controller.state_defs[arg_type].disc stat_var = agent.controller.state_defs[arg_type].static - set_states_3d(cont, disc, 'ego', state_dict[agent_id], cont_var, disc_var, stat_var) + set_states_3d( + cont, disc, 'ego', state_dict[agent_id], cont_var, disc_var, stat_var) else: controller_args = agent.controller.args arg_type = None @@ -106,7 +119,7 @@ class BaseSensor(): break if arg_type is None: raise ValueError(f"Invalid arg for others") - cont_var = agent.controller.state_defs[arg_type].cont + cont_var = agent.controller.state_defs[arg_type].cont disc_var = agent.controller.state_defs[arg_type].disc stat_var = agent.controller.state_defs[arg_type].static add_states_3d(cont, disc, arg_name, state_dict[agent_id], cont_var, disc_var, stat_var) diff --git a/requirements.txt b/requirements.txt index 49471cda4eb5d2a82085ba22f7ee57e43e27093b..01eb1fc3837f4aa68c23b514e282c1e00c9329c3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ numpy~=1.22.1 -scipy~=1.6.1 +scipy~=1.8.1 matplotlib~=3.4.2 polytope~=0.2.3 pyvista~=0.32.1 diff --git a/setup.py b/setup.py index 4ba70ef566fbbb5136be5748b815b81612cd48fc..8790ceb50a423984a4efeeb318219cdc8d6ee216 100644 --- a/setup.py +++ b/setup.py @@ -13,7 +13,7 @@ setup( python_requires='>=3.8', install_requires=[ "numpy~=1.22.1", - "scipy~=1.6.1", + "scipy~=1.8.1", "matplotlib~=3.4.2", "polytope~=0.2.3", "pyvista~=0.32.1",