# Skip to content
# Snippets Groups Projects
# db_guided_mrmp.py 19.93 KiB
"""
Database-guided multi-robot motion planning

This module is essentially one big path tracking algorithm. 
It uses MPC to path track each of the robots, while looking ahead to 
identify and resolve conflicts.
"""
import time as t
import pygame
from shapely.geometry import Polygon, Point

from guided_mrmp.utils import Env
from guided_mrmp.utils.helpers import *


from guided_mrmp.controllers.utils import get_ref_trajectory, compute_path_from_wp

from guided_mrmp.planners import RRT
from guided_mrmp.planners import RRTStar
from guided_mrmp.controllers.mpc import MPC
from guided_mrmp.controllers.multi_mpc import MultiMPC 
# from guided_mrmp.controllers.multi_path_tracking import MultiPathTracker

def plan_decoupled_path(env, start, goal, solver="RRT*", 
                        step_length=20, goal_sample_rate=.5, num_samples=500000, r=10):
    """
    Plan a decoupled path from a given start to a given goal, using a single-agent solver.

    inputs:
        - env (Env): environment handed to the planner
        - start (tuple): (x,y) location of start
        - goal (tuple): (x,y) location of goal
        - solver (string): name of the single-agent solver, "RRT" or "RRT*"
        - step_length (float): planner step length (RRTStar calls it ``step_len``)
        - goal_sample_rate (float): forwarded to the planner as its goal sampling rate
        - num_samples (int): planner iteration budget (RRTStar calls it ``iter_max``)
        - r (float): radius forwarded to RRT* only; ignored by plain RRT
    output:
        - (path, tree): ``path`` is the planner's path in reversed order
          (presumably start-first, assuming the planner returns it goal-first;
          confirm in the planner), or None if the planner returned None.
          ``tree`` is whatever the planner returns as its search tree.
          Returns (None, None) when ``solver`` is not recognised.
    """
    if solver == "RRT":
        planner = RRT(env, start, goal, step_length, goal_sample_rate, num_samples)
    elif solver == "RRT*":
        planner = RRTStar(env, start, goal, step_length, goal_sample_rate, num_samples, r)
    else:
        print(f"Solver {solver} is not yet implemented. Choose something else.")
        # Keep the 2-tuple shape so callers that unpack ``path, tree`` get
        # Nones instead of a TypeError.
        return None, None

    path, tree = planner.run()
    if path is None:
        # reversed(None) would raise; surface the planner failure as None.
        return None, tree
    return list(reversed(path)), tree

class GuidedMRMP:
    """
    Database-guided multi-robot motion planner / path tracker.

    Each robot tracks a guide path built from its ``rrtpath`` waypoints using
    MPC. Each call to ``advance`` computes the next controls, checks the
    predicted trajectories for robot-robot overlaps and writes the chosen
    controls to ``robot.control``.
    """

    # Feature switches for ``advance``. The defaults reproduce the previously
    # hard-coded ``if False:`` / ``if True:`` toggles.
    # Use the joblib-parallel per-robot MPC instead of the joint multi-robot MPC.
    _USE_PARALLEL_MPC = False
    # Run the conflict-resolution branch when conflicts are detected.
    # NOTE(review): that branch needs ``MultiPathTracker``, whose import is
    # commented out at the top of this module; restore it before enabling.
    _RESOLVE_CONFLICTS = False

    def __init__(self, env, robots, dynamics_models, T, DT, libs,settings):
        """
        inputs:
            - env (Env): the environment
            - robots (list): list of Robot class objects; each must expose
              ``rrtpath`` (sequence of waypoints whose first two entries are x, y)
            - dynamics_models (list): dynamics model per robot, index-aligned with ``robots``
            - T (float): MPC horizon, passed to MPC and the reference generator
            - DT (float): MPC time step
            - libs: stored on the instance; not used by the active code path
              (the commented-out database resolver indexes libs[0..2])
            - settings (dict): must contain the 'model_predictive_controller',
              'multi_robot_traj_opt' and 'simulator' sections read below
        """
        self.robots = robots
        self.dynamics_models = dynamics_models
        self.env = env
        # One independent list per robot (``[[]] * n`` would alias a single list).
        self.guide_paths = [[] for _ in robots]
        self.settings = settings

        self.T = T
        self.DT = DT
        # NOTE(review): int() truncates, so float round-off (e.g. 0.3 / 0.1 ->
        # 2.999...) gives one step fewer; kept as-is in case other modules
        # derive the step count the same way.
        self.K = int(T / DT)

        # for MPC
        mpc_settings = self.settings['model_predictive_controller']
        self.Q = mpc_settings['Q']    # state error cost
        self.Qf = mpc_settings['Qf']  # state final error cost
        self.R = mpc_settings['R']    # input cost
        self.P = mpc_settings['P']    # input rate of change cost

        # for conflict resolution
        opt_settings = self.settings['multi_robot_traj_opt']
        self.obstacle_weight = opt_settings['obstacle_weight']
        self.rob_dist_weight = opt_settings['rob_dist_weight']
        self.time_weight = opt_settings['time_weight']
        self.control_weight = opt_settings['control_weight']
        self.goal_weight = opt_settings['goal_weight']

        for idx, r in enumerate(self.robots):
            xs = [node[0] for node in r.rrtpath]
            ys = [node[1] for node in r.rrtpath]
            # Build the tracked reference path from the RRT waypoints; 0.05 is
            # passed as compute_path_from_wp's resolution argument (confirm units).
            self.guide_paths[idx] = compute_path_from_wp(xs, ys, .05)

        self.scaling_factor = self.settings['simulator']['scaling_factor']

        self.libs = libs

    def ego_to_global(self, robot, mpc_out):
        """
        Transform the XY rows of an ego-frame (robot-frame) trajectory into the
        global (map) frame.

        Args:
            robot: object whose ``current_position`` is indexable as (x, y, theta)
            mpc_out (np.ndarray): 2-D array; rows 0 and 1 are ego-frame x and y,
                one column per time step. Any further rows are ignored.

        Returns:
            np.ndarray: new 2 x K array of global x (row 0) and y (row 1); the
            input array is not modified.
        """
        x = robot.current_position[0]
        y = robot.current_position[1]
        theta = robot.current_position[2]

        # Rotation matrix to transform points from ego frame to global frame
        rotation = np.array([
            [np.cos(theta), -np.sin(theta)],
            [np.sin(theta), np.cos(theta)]
        ])

        # dot() returns a fresh array, so the in-place translation below does
        # not touch mpc_out.
        trajectory = rotation.dot(mpc_out[0:2, :])
        trajectory[0, :] += x
        trajectory[1, :] += y

        return trajectory

    def find_all_conflicts(self, desired_trajs, dt):
        """
        Check for any conflicts that will occur between robots in the next time horizon.

        Robots i and j conflict when, at some common time index of their
        predicted trajectories, the discs of radius ``robot.radius`` centred at
        their predicted positions overlap or touch
        (distance <= radius_i + radius_j). When trajectory lengths differ, only
        the common prefix is compared.

        inputs:
            - desired_trajs (list): per robot (index-aligned with self.robots),
              a 2-D array-like whose rows 0 and 1 are global x and y over time
            - dt (float): unused; kept for interface compatibility
        output:
            - conflicts (list): (robot_i, robot_j) tuples with i < j, each pair at most once
        """
        # Exact disc-disc test on the raw coordinates: cheaper than building
        # shapely buffers per time step, and it does not suffer from their
        # polygonal approximation of the circle.
        xy = [np.asarray(desired_trajs[idx])[:2] for idx in range(len(self.robots))]

        conflicts = []
        for i, r1 in enumerate(self.robots):
            p1 = xy[i]
            for j in range(i + 1, len(self.robots)):
                r2 = self.robots[j]
                p2 = xy[j]
                steps = min(p1.shape[1], p2.shape[1])  # zip() truncation semantics
                gaps = np.hypot(p1[0, :steps] - p2[0, :steps],
                                p1[1, :steps] - p2[1, :steps])
                if np.any(gaps <= r1.radius + r2.radius):
                    conflicts.append((r1, r2))

        return conflicts

    def _build_reference_and_mpc(self, idx, robot):
        """Return (target_traj, mpc): the robot's reference trajectory along its guide path and a fresh MPC."""
        # Reference trajectory inputs are in the world frame.
        target_traj = get_ref_trajectory(np.array(robot.current_position),
                                         np.array(self.guide_paths[idx]),
                                         robot.target_v,
                                         self.T,
                                         self.DT)
        mpc = MPC(self.dynamics_models[idx], self.T, self.DT, self.Q, self.Qf, self.R, self.P,
                  self.settings['model_predictive_controller'])
        return target_traj, mpc

    def get_next_controls_and_trajectories_parallel(self, screen, n_jobs=12):
        """
        Get the next control for each robot, solving the per-robot MPCs with joblib.

        The previous MPC solution of each robot (``last_mpc_trajectory`` /
        ``last_mpc_controls``) is used as the initial guess when available.
        The new solution is stored back on the robot in the calling process.
        Attribute writes made inside process-based joblib workers would be lost.

        inputs:
            - screen: unused; kept so the signature matches the other variants
            - n_jobs (int): number of joblib workers
        outputs:
            - next_controls (list): per robot, np.asarray([u_mpc[0, :], u_mpc[1, :]])
            - next_trajs (list): per robot, predicted XY trajectory in the global frame
        """
        from joblib import Parallel, delayed

        def process_robot(idx, r):
            target_traj, mpc = self._build_reference_and_mpc(idx, r)

            # dynamics w.r.t robot frame
            curr_state = np.array([0, 0, 0])
            if r.last_mpc_trajectory is not None:
                initial_guess = {'X': r.last_mpc_trajectory, 'U': r.last_mpc_controls}
                x_mpc, u_mpc = mpc.step(curr_state, target_traj, np.array(r.control), initial_guess)
            else:
                x_mpc, u_mpc = mpc.step(curr_state, target_traj, np.array(r.control))

            # Whole horizon of the first two control rows; the caller is
            # expected to apply only the first step.
            control = np.asarray([u_mpc[0, :], u_mpc[1, :]])
            return control, self.ego_to_global(r, x_mpc), x_mpc, u_mpc

        with Parallel(n_jobs=n_jobs) as parallel:
            results = parallel(delayed(process_robot)(idx, r) for idx, r in enumerate(self.robots))

        next_controls = []
        next_trajs = []
        for r, (control, traj, x_mpc, u_mpc) in zip(self.robots, results):
            # Persist the warm start here, in the parent process.
            r.last_mpc_trajectory = x_mpc
            r.last_mpc_controls = u_mpc
            next_controls.append(control)
            next_trajs.append(traj)

        return next_controls, next_trajs

    def get_next_controls_and_trajectories(self,screen):
        """
        Get the next control for each robot by solving one MPC per robot, sequentially.

        Also draws each robot's predicted trajectory on ``screen``.

        outputs:
            - next_controls (list): per robot, np.asarray([u_mpc[0, :], u_mpc[1, :]])
            - next_trajs (list): per robot, predicted XY trajectory in the global frame
        """
        next_controls = []
        next_trajs = []
        for idx, r in enumerate(self.robots):
            target_traj, mpc = self._build_reference_and_mpc(idx, r)

            # dynamics w.r.t robot frame
            curr_state = np.array([0, 0, 0])
            x_mpc, u_mpc = mpc.step(curr_state, target_traj, np.array(r.control))

            self.add_vis_target_traj(screen, self.scaling_factor, r, x_mpc)

            # only the first one is used to advance the simulation
            next_controls.append(np.asarray([u_mpc[0, :], u_mpc[1, :]]))

            print(f"x_mpc = {x_mpc}")
            next_trajs.append(self.ego_to_global(r, x_mpc))

        return next_controls, next_trajs

    def get_next_controls_and_trajectories_multi_robot(self,screen):
        """
        Get the next control for each robot, using one MPC that considers all robots at once.

        All robots share ``dynamics_models[0]``. The joint solution is split per
        robot: state rows [3i, 3i+3) and control rows [2i, 2i+2) belong to robot i.

        outputs:
            - next_controls (list): per robot, [u_mpc[2i, 0], u_mpc[2i+1, 0]] (first step only)
            - next_trajs (list): per robot, predicted XY trajectory in the global frame
        """
        multiMPC = MultiMPC(len(self.robots), self.dynamics_models[0], self.T, self.DT, self.Q, self.Qf, self.R, self.P, self.settings['model_predictive_controller'])

        # 1. Get the reference trajectory for each robot (world frame)
        state = [r.current_position for r in self.robots]
        targets = []
        for i, robot in enumerate(self.robots):
            print(f"state i = {state[i]}")
            targets.append(get_ref_trajectory(np.array(state[i]), np.array(self.guide_paths[i]), robot.target_v, self.T, self.DT))

        # 2. Solve with every robot at the origin of its own frame
        curr_states = np.zeros((len(self.robots), 3))
        controls = [r.control for r in self.robots]
        x_mpc, u_mpc = multiMPC.step(
            curr_states,
            targets,
            controls
        )

        # 3. Split the joint solution per robot
        next_controls = []
        for i in range(len(self.robots)):
            next_controls.append([u_mpc[i*2, 0], u_mpc[i*2+1, 0]])

        next_trajs = []
        for i, robot in enumerate(self.robots):
            rows = [x_mpc[i*3,:], x_mpc[i*3+1,:], x_mpc[i*3+2,:]]
            print(rows)
            next_trajs.append(self.ego_to_global(robot, np.array(rows)))

        return next_controls, next_trajs

    def advance(self, screen, state, time, dt=0.1):
        """
        Advance the simulation by one timestep.

        inputs:
            - screen: pygame surface forwarded to the control computation
            - state (list): list of tuples, where each tuple is the state of a robot
              (currently unused; states are read from ``robot.current_position``)
            - time (float): current time (currently unused)
            - dt (float): time step, forwarded to ``find_all_conflicts``
        output:
            - (resolved, controls): ``resolved`` is False when the desired MPC
              controls were applied unchanged and True when the conflict
              resolver produced them. ``controls`` is the per-robot list that
              is also written to ``robot.control``.
        """
        # get the next control for each robot
        if self._USE_PARALLEL_MPC:
            next_desired_controls, trajectories = self.get_next_controls_and_trajectories_parallel(screen)
        else:
            next_desired_controls, trajectories = self.get_next_controls_and_trajectories_multi_robot(screen)

        # find all the conflicts at the next time horizon
        conflicts = self.find_all_conflicts(trajectories, dt)

        print(f"conflicts = {conflicts}")

        if not conflicts or not self._RESOLVE_CONFLICTS:
            # no conflicts (or resolution disabled): apply the desired controls
            controls = [np.asarray(control) for control in next_desired_controls]
            for r, next_control in zip(self.robots, controls):
                r.control = next_control
            return False, controls

        # resolve the conflicts
        # Database-backed alternative (prototype):
        # rad = self.robots[0].radius
        # resolver = TrajOptDBResolver(cell_size = rad*2.5, 
        #                              grid_size = 5, 
        #                              all_conflicts = conflicts, 
        #                              all_robots=self.robots, 
        #                              trajs=trajectories,
        #                              robot_radius=rad, 
        #                              env=self.env, 
        #                              rob_dist_weight=self.rob_dist_weight, 
        #                              obs_dist_weight=self.obstacle_weight, 
        #                              control_weight=self.control_weight, 
        #                              time_weight=self.time_weight, 
        #                              goal_weight=self.goal_weight, 
        #                              lib_2x3=self.libs[0], 
        #                              lib_3x3=self.libs[1], 
        #                              lib_2x5=self.libs[2])

        starts = [r.current_position for r in self.robots]
        print(f"starts = {starts}")
        print(f"trajs = {trajectories}")
        resolver = MultiPathTracker(starts, self.dynamics_models[0], 3.0, self.T, self.DT, trajectories, self.settings)
        resolver.run()

        # Entry 1 of each history is taken as the next command (v, a_history) pair.
        next_desired_controls = [[v[1], omega[1]]
                                 for v, omega in zip(resolver.v_history, resolver.a_history)]

        # Subgoal rerouting (prototype): replan each robot's guide path from the
        # resolver's subgoal to its goal and prepend the current position.
        # starts, goals = resolver.get_temp_starts_and_goals()
        # continuous_subgoals = [resolver.get_grid_cell_location(g[0], g[1]) for g in goals]
        # for idx, r in enumerate(self.robots):
        #     subgoal = continuous_subgoals[idx]
        #     start = r.current_position
        #     goal = r.goal
        #     rrtstar = RRTStar(self.env, (subgoal[0],subgoal[1]), (goal[0],goal[1]), step_len=20, goal_sample_rate=.5, iter_max=500000, r=80, sampled_vertices=r.tree)
        #     rrtpath,tree = rrtstar.run()
        #     rrtpath = list(reversed(rrtpath))
        #     xs = [start[0]] + [node[0] for node in rrtpath]
        #     ys = [start[1]] + [node[1] for node in rrtpath]
        #     self.guide_paths[idx] = compute_path_from_wp(xs, ys, .05)

        # return the valid controls
        for r, next_control in zip(self.robots, next_desired_controls):
            r.control = next_control

        return True, next_desired_controls

    def add_vis_guide_paths(self, screen, scaling_factor):
        """
        Draw each robot's RRT tree: one line from every node to its parent,
        in the robot's colour, with coordinates multiplied by ``scaling_factor``.
        """
        for r in self.robots:
            for node in r.tree:
                print(f"node = {node}")
                if node.parent is not None:
                    print(f"parent = {node.parent}")
                    # node.parent is indexed as a coordinate pair here.
                    pygame.draw.line(screen, r.color, (node.current[0]*scaling_factor,node.current[1]*scaling_factor), 
                                     (node.parent[0]*scaling_factor,node.parent[1]*scaling_factor), 2)

    def add_vis_grid(self, screen, grid_size, top_left, cell_size):
        """
        Draw a grid_size x grid_size grid of black lines, each cell ``cell_size`` wide.

        Vertical lines run from ``top_left[1]`` toward decreasing y; horizontal
        lines run from ``top_left[0]`` toward increasing x. Coordinates are used
        as given (no scaling).
        """
        line_len = grid_size*cell_size
        for i in range(grid_size+1):
            offset = i*cell_size
            # vertical line i
            pygame.draw.line(screen, (0,0,0),
                             (top_left[0] + offset, top_left[1]),
                             (top_left[0] + offset, top_left[1] - line_len), 2)
            # horizontal line i
            pygame.draw.line(screen, (0,0,0),
                             (top_left[0], top_left[1] - offset),
                             (top_left[0] + line_len, top_left[1] - offset), 2)

    def add_vis_target_traj(self,screen, scaling_factor, robot, traj):
        """
        Draw a robot's ego-frame predicted trajectory as a red polyline, after
        converting it to the global frame and multiplying by ``scaling_factor``.
        """
        pts = self.ego_to_global(robot, traj)
        for i in range(len(pts[0])-1):
            start = (pts[0,i]*scaling_factor, pts[1,i]*scaling_factor)
            end = (pts[0,i+1]*scaling_factor, pts[1,i+1]*scaling_factor)
            pygame.draw.line(screen, (255,0,0), start, end, 3)

    def add_vis_conflict_free_traj(self,screen, scaling_factor, trajs):
        """
        Draw already-global trajectories as green polylines.

        ``trajs`` is a sequence of [xs, ys] pairs; coordinates are multiplied
        by ``scaling_factor``.
        """
        for traj in trajs:
            for i in range(len(traj[0])-1):
                x = traj[0][i]
                y = traj[1][i]
                next_x = traj[0][i+1]
                next_y = traj[1][i+1]
                pygame.draw.line(screen, (0,255,0), (x*scaling_factor,y*scaling_factor), 
                                (next_x*scaling_factor,next_y*scaling_factor), 2)