diff --git a/dryvr_plus_plus/example/example_agent/ball_agent.py b/dryvr_plus_plus/example/example_agent/ball_agent.py index 87bcbe382f92ff46ae223966201c0b9f173357ea..72157f50ceea1de14c6a6414aab96cbb0af16a10 100644 --- a/dryvr_plus_plus/example/example_agent/ball_agent.py +++ b/dryvr_plus_plus/example/example_agent/ball_agent.py @@ -17,6 +17,7 @@ class BallAgent(BaseAgent): @staticmethod def dynamic(t, state, u): + '''Defines the RHS of the ODE used to simulate trajectories''' x, y, vx, vy = state x_dot = vx y_dot = vy @@ -24,34 +25,9 @@ class BallAgent(BaseAgent): vy_dot = 0 return [x_dot, y_dot, vx_dot, vy_dot] - def action_handler(self, mode: List[str], state, lane_map:LaneMap)->Tuple[float, float]: - x,y,theta,v = state - vehicle_mode = mode[0] - vehicle_lane = mode[1] - vehicle_pos = np.array([x,y]) - a = 0 - if vehicle_mode == "Normal": - d = -lane_map.get_lateral_distance(vehicle_lane, vehicle_pos) - elif vehicle_mode == "SwitchLeft": - d = -lane_map.get_lateral_distance(vehicle_lane, vehicle_pos) + 3 - elif vehicle_mode == "SwitchRight": - d = -lane_map.get_lateral_distance(vehicle_lane, vehicle_pos) - 3 - elif vehicle_mode == "Brake": - d = -lane_map.get_lateral_distance(vehicle_lane, vehicle_pos) - a = -1 - elif vehicle_mode == "Accel": - d = -lane_map.get_lateral_distance(vehicle_lane, vehicle_pos) - a = 1 - elif vehicle_mode == 'Stop': - d = -lane_map.get_lateral_distance(vehicle_lane, vehicle_pos) - a = 0 - psi = lane_map.get_lane_heading(vehicle_lane, vehicle_pos)-theta - steering = psi + np.arctan2(0.45*d, v) - steering = np.clip(steering, -0.61, 0.61) - return steering, a - def TC_simulate(self, mode: List[str], initialCondition, time_bound, lane_map:LaneMap=None)->np.ndarray: time_step = 0.05 + # Looks like this should be a global parameter; some config file should be setting this. time_bound = float(time_bound) number_points = int(np.ceil(time_bound/time_step)) t = [round(i*time_step,10) for i in range(0,number_points)] @@ -59,7 +35,6 @@ class BallAgent(BaseAgent): init = initialCondition trace = [[0]+init] for i in range(len(t)): - steering, a = self.action_handler(mode, init, lane_map) r = ode(self.dynamic) r.set_initial_value(init).set_f_params([steering, a]) res:np.ndarray = r.integrate(r.t + time_step) @@ -69,3 +44,6 @@ class BallAgent(BaseAgent): trace.append([t[i] + time_step] + init) return np.array(trace) + +if __name__ == '__main__': + aball = BallAgent('red_ball',file_name=input_code_name) \ No newline at end of file