Skip to content
Snippets Groups Projects
multi_path_tracking.py 10 KiB
Newer Older
import numpy as np
import matplotlib.pyplot as plt
from guided_mrmp.controllers.utils import compute_path_from_wp, get_ref_trajectory
from guided_mrmp.controllers.multi_mpc import MultiMPC

class MultiPathTracker:
    """Track reference paths for several robots with a joint multi-robot MPC.

    Each robot follows a path built from its waypoints. At every simulation
    step a single ``MultiMPC`` problem is solved for all robots. The first
    control of each robot is then applied through ``dynamics.next_state``.
    """

    def __init__(self, env, initial_positions, dynamics, target_v, T, DT, waypoints, settings, lib_2x3, lib_3x3, lib_2x5):
        """
        Initializes the MultiPathTracker object.

        Parameters:
        - env: Environment object; stored as ``self.env`` (not used elsewhere in this class).
        - initial_positions: List of the initial positions of the robots [x, y, heading].
        - dynamics: The dynamics model of the robots. This class reads ``dynamics.radius``
          and calls ``dynamics.next_state(state, control, DT)``. It is also passed to ``MultiMPC``.
        - target_v: The target velocity of the robots.
        - T: The time horizon for the model predictive control (MPC).
        - DT: The time step for the MPC.
        - waypoints: One entry per robot. Each robot's path is built as
          ``compute_path_from_wp(wp[0], wp[1], 0.05)``.
        - settings: Nested config dict. Reads ``settings['model_predictive_controller']``
          (keys ``Q``, ``Qf``, ``R``, ``P``; also forwarded whole to ``MultiMPC``)
          and ``settings['environment']['circle_obstacles']``.
        - lib_2x3, lib_3x3, lib_2x5: Libraries for the discrete solver; only stored here.
        """
        self.env = env

        # State of each robot [x, y, heading]
        self.states = initial_positions
        self.num_robots = len(initial_positions)
        self.dynamics = dynamics
        self.T = T
        self.DT = DT
        self.target_v = target_v

        self.radius = dynamics.radius

        # When set, run() calls update_reference_paths() after the next step.
        self.update_ref_paths = False

        # Latest control of every robot; fed back into the MPC. Starts at 0, 0.
        self.control = np.zeros((self.num_robots, 2))

        self.K = int(T / DT)

        # For a car model
        # Q = [20, 20, 10, 20]  # state error cost
        # Qf = [30, 30, 30, 30]  # state final error cost
        # R = [10, 10]  # input cost
        # P = [10, 10]  # input rate of change cost
        # self.mpc = MPC(VehicleModel(), T, DT, Q, Qf, R, P)

        # libraries for the discrete solver
        self.lib_2x3 = lib_2x3
        self.lib_3x3 = lib_3x3
        self.lib_2x5 = lib_2x5

        # For a circular robot (easy dynamics)
        mpc_settings = settings['model_predictive_controller']
        Q = mpc_settings['Q']  # state error cost
        Qf = mpc_settings['Qf']  # state final error cost
        R = mpc_settings['R']  # input cost
        P = mpc_settings['P']  # input rate of change cost
        self.circle_obs = settings['environment']['circle_obstacles']
        self.mpc = MultiMPC(self.num_robots, dynamics, T, DT, Q, Qf, R, P, mpc_settings, self.circle_obs)

        # Path from waypoint interpolation
        self.paths = [compute_path_from_wp(wp[0], wp[1], 0.05) for wp in waypoints]

        # One independent list per robot. `[[]] * n` would make every robot
        # share the same list object.
        self.visited_points_on_guide_paths = [[] for _ in range(self.num_robots)]

        print(f"paths = {len(self.paths)}")

        # Helper variables to keep track of the sim
        self.sim_time = 0
        self.x_history = [[] for _ in range(self.num_robots)]
        self.y_history = [[] for _ in range(self.num_robots)]
        self.v_history = [[] for _ in range(self.num_robots)]
        self.h_history = [[] for _ in range(self.num_robots)]
        self.a_history = [[] for _ in range(self.num_robots)]
        self.d_history = [[] for _ in range(self.num_robots)]
        self.optimized_trajectories_hist = [[] for _ in range(self.num_robots)]
        self.optimized_trajectory = None

    def trajectories_overlap(self, traj1, traj2, threshold):
        """
        Checks if two trajectories overlap. Only xy positions (rows 0 and 1) are compared.

        Every point of ``traj1`` is compared against every point of ``traj2``,
        not only points at the same time index.

        Args:
            traj1 (3xn numpy array): First trajectory. First row is x, second row is y, third row is heading.
            traj2 (3xm numpy array): Second trajectory.
            threshold (float): Per-robot clearance. Two points collide when their
                distance is strictly below ``2 * threshold``.
        Returns:
            bool: True if trajectories overlap, False otherwise (also False if
            either trajectory has no points).
        """
        # Pairwise xy differences via broadcasting: shape (2, n, m).
        deltas = traj1[0:2, :, np.newaxis] - traj2[0:2, np.newaxis, :]
        distances = np.linalg.norm(deltas, axis=0)
        return bool(np.any(distances < 2 * threshold))

    def ego_to_global_roomba(self, state, mpc_out):
        """
        Transforms optimized trajectory XY points from ego (robot) reference
        into global (map) frame.

        Args:
            state: Robot pose ``[x, y, theta]`` in the global frame (theta in radians).
            mpc_out (numpy array): Optimized trajectory in the ego frame; rows 0 and 1 are x and y.

        Returns:
            numpy array: 2xN array of the trajectory's xy points in the global frame.
            ``mpc_out`` is not modified.
        """
        x, y, theta = state[0], state[1], state[2]

        # Rotation matrix to transform points from ego frame to global frame
        rotm = np.array([
            [np.cos(theta), -np.sin(theta)],
            [np.sin(theta), np.cos(theta)]
        ])

        # Rotate (this produces a new array), then translate to the robot's position.
        trajectory = rotm.dot(mpc_out[0:2, :])
        trajectory[0, :] += x
        trajectory[1, :] += y
        return trajectory

    def get_next_control(self, state, show_plots=False):
        """
        Solve the joint MPC problem once and return the first control of every robot.

        Args:
            state: Per-robot states ``[x, y, heading]`` in the world frame, indexable by robot.
            show_plots (bool): Currently unused.

        Returns:
            tuple: ``(x_mpc, controls)``. ``x_mpc`` is the state output of
            ``MultiMPC.step``. ``controls`` is a list with one ``[u0, u1]`` pair
            per robot; it is also stored in ``self.control``.
        """
        # 1. Reference trajectory for each robot (get_ref_trajectory takes world-frame inputs).
        targets = [
            get_ref_trajectory(np.array(state[i]), np.array(self.paths[i]), self.target_v,
                               self.T, self.DT, len(self.x_history[i]) + 1)
            for i in range(self.num_robots)
        ]

        # Dynamics w.r.t. the robot frame: each robot's current state is passed as zeros.
        curr_states = np.zeros((self.num_robots, 3))
        x_mpc, u_mpc = self.mpc.step(
            curr_states,
            targets,
            self.control
        )

        # Only the first control of each robot is used to advance the simulation.
        # u_mpc stacks two rows per robot: rows 2*i and 2*i + 1 belong to robot i.
        self.control = [[u_mpc[i * 2, 0], u_mpc[i * 2 + 1, 0]] for i in range(self.num_robots)]

        return x_mpc, self.control

    def done(self, tol=1.0):
        """
        Return True when every robot is within ``tol`` of the last point of its path.

        Args:
            tol (float): Maximum xy distance to the final path point. Defaults to 1.

        Returns:
            bool: False as soon as one robot is farther than ``tol`` from its goal.
        """
        for i in range(self.num_robots):
            dx = self.states[i][0] - self.paths[i][0, -1]
            dy = self.states[i][1] - self.paths[i][1, -1]
            if np.sqrt(dx ** 2 + dy ** 2) > tol:
                return False
        return True

    def plot_current_world_state(self):
        """
        Plot the current state of the world.

        The plot contains:
        - each robot's latest pose from the history lists;
        - its goal (the last point of its path), drawn as a marker plus an
          outline circle of the robot radius;
        - its reference path, drawn as a dashed line.

        The axes are fixed to [0, 10] x [0, 10] with equal aspect ratio, and
        the method then calls ``plt.show()``.
        """
        import matplotlib.cm as cm

        colors = cm.rainbow(np.linspace(0, 1, self.num_robots))

        # Plot the current state of each robot using the most recent history values.
        for i in range(self.num_robots):
            plot_roomba(self.x_history[i][-1], self.y_history[i][-1], self.h_history[i][-1], colors[i], False, self.radius)

        # plot the goal of each robot with solid circle
        for i in range(self.num_robots):
            x, y, _ = self.paths[i][:, -1]
            plt.plot(x, y, 'o', color=colors[i])
            goal_circle = plt.Circle((x, y), self.radius, color=colors[i], fill=False)
            plt.gca().add_artist(goal_circle)

        # plot the ref path of each robot
        for i in range(self.num_robots):
            plt.plot(self.paths[i][0, :], self.paths[i][1, :], '--', color=colors[i])

        plt.xlim(0, 10)
        plt.ylim(0, 10)

        # force equal aspect ratio
        plt.gca().set_aspect('equal', adjustable='box')

        plt.show()

    def _record_states(self):
        """Append the current x, y and heading of every robot to the history lists."""
        for i in range(self.num_robots):
            self.x_history[i].append(self.states[i, 0])
            self.y_history[i].append(self.states[i, 1])
            self.h_history[i].append(self.states[i, 2])

    def run(self, show_plots=False):
        """
        Run the path tracker algorithm.

        The loop runs until ``done()`` reports that every robot is near its goal.
        On each step it solves the MPC for all robots, applies the first control
        through ``dynamics.next_state`` and records the new states.

        Parameters:
        - show_plots (bool): Flag indicating whether to show plots during the simulation. Default is False.
          When True, ``plot_current_world_state`` (which calls ``plt.show()``) runs before every step.
        Returns:
        - numpy.ndarray: Array containing the history of x, y, and h coordinates,
          built from ``[x_history, y_history, h_history]``.
        """
        # Add the initial state to the histories
        self.states = np.array(self.states)
        self._record_states()

        while True:
            # check if all robots have reached their goal
            if self.done():
                print("Success! Goal Reached")
                return np.asarray([self.x_history, self.y_history, self.h_history])

            if show_plots:
                self.plot_current_world_state()

            # NOTE(review): the optimized-trajectory preview (ego_to_global_roomba) is
            # not computed here. It needs a per-robot state plus that robot's slice of
            # x_mpc, and that layout is defined by MultiMPC.step — confirm before adding.
            _x_mpc, controls = self.get_next_control(self.states)

            self.states = np.array([
                self.dynamics.next_state(self.states[i], controls[i], self.DT)
                for i in range(self.num_robots)
            ])
            self._record_states()

            # NOTE(review): update_reference_paths is not defined in this class. This
            # branch only runs if a caller sets update_ref_paths = True.
            if self.update_ref_paths:
                self.update_reference_paths()
                self.update_ref_paths = False

def plot_roomba(x, y, yaw, color, fill, radius):
    """
    Draw a circular robot and its heading arrow on the current matplotlib axes.

    Args:
        x (float): x coordinate of the robot centre.
        y (float): y coordinate of the robot centre.
        yaw (float): heading angle in radians; the arrow points along it.
        color: matplotlib colour used for the robot's circle.
        fill (bool): if truthy, draw a filled disc with alpha 0.3;
            otherwise draw an opaque (alpha 1) outline.
        radius (float): radius of the drawn circle.
    """
    axes = plt.gca()
    body = plt.Circle((x, y), radius, color=color, fill=fill, alpha=.3 if fill else 1)
    axes.add_patch(body)

    # Heading marker: a red arrow of length 1 starting at the centre.
    axes.arrow(x, y, np.cos(yaw), np.sin(yaw), head_width=0.1, head_length=0.1, fc='r', ec='r')



# if __name__ == "__main__":
#     main()