From fce350a63803fdb68fe17fea860a18cdf8d106fb Mon Sep 17 00:00:00 2001
From: rachelmoan <moanrachel516@gmail.com>
Date: Thu, 6 Feb 2025 12:27:03 -0600
Subject: [PATCH] Use one MPC for both single and multirobot

---
 guided_mrmp/controllers/multi_mpc.py          |   3 -
 .../controllers/multi_path_tracking.py        |   2 -
 .../controllers/multi_path_tracking_db.py     | 132 ++++++++++++------
 guided_mrmp/controllers/utils.py              |  88 ++++++++++--
 4 files changed, 169 insertions(+), 56 deletions(-)

diff --git a/guided_mrmp/controllers/multi_mpc.py b/guided_mrmp/controllers/multi_mpc.py
index f171cdc..fee4976 100644
--- a/guided_mrmp/controllers/multi_mpc.py
+++ b/guided_mrmp/controllers/multi_mpc.py
@@ -214,8 +214,6 @@ class MultiMPC:
         """
         As, Bs, Cs = [], [], []
         for i in range(self.num_robots):
-            # print(f"initial_state[i] = {initial_state[i]}")
-            # print(f"prev_cmd[i] = {prev_cmd[i]}")
             A, B, C = self.robot_model.linearize(initial_state[i], prev_cmd[i], self.dt)
             As.append(A)
             Bs.append(B)
@@ -223,7 +221,6 @@ class MultiMPC:
 
         solver_options = {'ipopt.print_level': self.print_level, 
                           'print_time': self.print_time, 
-                        #   'ipopt.tol': 1e-3,
                           'ipopt.acceptable_tol': self.acceptable_tol, 
                           'ipopt.acceptable_iter': self.acceptable_iter}
 
diff --git a/guided_mrmp/controllers/multi_path_tracking.py b/guided_mrmp/controllers/multi_path_tracking.py
index 37d7de6..cfadef7 100644
--- a/guided_mrmp/controllers/multi_path_tracking.py
+++ b/guided_mrmp/controllers/multi_path_tracking.py
@@ -44,8 +44,6 @@ class MultiPathTracker:
 
         self.coupled_mpc = MultiMPC(self.num_robots, dynamics, T, DT, settings, env.circle_obs, env.rect_obs)
 
-        self.mpc = MPC(dynamics, T, DT, settings, env.circle_obs, env.rect_obs)
-
         self.circle_obs = env.circle_obs
         self.rect_obs = env.rect_obs
 
diff --git a/guided_mrmp/controllers/multi_path_tracking_db.py b/guided_mrmp/controllers/multi_path_tracking_db.py
index be1a0b8..982a310 100644
--- a/guided_mrmp/controllers/multi_path_tracking_db.py
+++ b/guided_mrmp/controllers/multi_path_tracking_db.py
@@ -12,7 +12,7 @@ from shapely.geometry import Point
 from shapely.geometry import Polygon
 from guided_mrmp.utils.helpers import plan_decoupled_path
 
-from guided_mrmp.controllers.mpc import MPC
+from guided_mrmp.controllers.multi_mpc import MultiMPC
 from guided_mrmp.controllers.place_grid import place_grid
 
 class DiscreteRobot:
@@ -69,6 +69,15 @@ class MultiPathTrackerDB(MultiPathTracker):
 
         return temp_starts, temp_goals
     
+    def get_subgoals(self, state, robots_in_conflict):
+        subgoals = []
+        for idx in robots_in_conflict:
+            traj = self.ego_to_global_roomba(state[idx], self.trajs[idx])
+            x = traj[0][-1]
+            y = traj[1][-1]
+            subgoals.append((x, y))
+        return subgoals     
+
     def create_discrete_robots(self, starts, goals):
         discrete_robots = []
         for i in range(len(starts)):
@@ -274,19 +283,25 @@ class MultiPathTrackerDB(MultiPathTracker):
                                                            self.target_v, 
                                                            self.T, 
                                                            self.DT, 
-                                                           [])
+                                                           self.visited_points_on_guide_paths[i])
             
             self.visited_points_on_guide_paths[i] = visited_guide_points
 
             targets.append(ref)
         
+            mpc = MultiMPC(1, # num robots
+                           self.dynamics, 
+                           self.T, 
+                           self.DT, 
+                           self.settings, 
+                           self.env.circle_obs, 
+                           self.env.rect_obs)
             
-
-            curr_state = np.array([0,0,0])
-            x_mpc, u_mpc = self.mpc.step(
+            curr_state = np.zeros((1, 3))
+            x_mpc, u_mpc = mpc.step(
                 curr_state,
-                ref,
-                self.control[i]
+                [ref],
+                [self.control[i]]
             )
 
             x_mpc_global = self.ego_to_global_roomba(state[i], x_mpc)
@@ -314,15 +329,21 @@ class MultiPathTrackerDB(MultiPathTracker):
                                             [])
             
             
-            mpc = MPC(self.dynamics, self.T, self.DT, self.settings)
+            mpc = MultiMPC(1, # num robots
+                           self.dynamics, 
+                           self.T, 
+                           self.DT, 
+                           self.settings, 
+                           self.env.circle_obs, 
+                           self.env.rect_obs)
             
-            
-            curr_state = np.array([0,0,0])
+            curr_state = np.zeros((1, 3))
             x_mpc, u_mpc = mpc.step(
-                    curr_state,
-                    ref,
-                    current_control
-                )
+                curr_state,
+                [ref],
+                [self.control[idx]]
+            )
+
                 
             # only the first one is used to advance the simulation
             control = [u_mpc[0, 0], u_mpc[1, 0]]
@@ -379,15 +400,21 @@ class MultiPathTrackerDB(MultiPathTracker):
 
             
             # Put down a local grid
-            self.cell_size = self.radius*3
+            self.cell_size = self.radius*2
             self.grid_size = 5
 
-            if diff < 4*self.cell_size:
+            if diff < 5*self.cell_size:
+                circle_obs = []
+                for obs in self.circle_obs:
+                    circle_obs.append(('c',obs[0], obs[1], obs[2]))
+
+                subgoals = self.get_subgoals(state, c)
+
                 grid_origin, centers = place_grid(robot_positions, 
-                                                  cell_size=self.radius*3, 
+                                                  cell_size=self.radius*2, 
                                                   grid_size=5, 
-                                                  subgoals=[],
-                                                  obstacles=self.circle_obs)
+                                                  subgoals=subgoals,
+                                                  obstacles=circle_obs)
                 grid_obstacle_map = self.get_obstacle_map(grid_origin, self.grid_size, self.radius) 
 
                 # Solve a discrete version of the problem 
@@ -409,6 +436,7 @@ class MultiPathTrackerDB(MultiPathTracker):
 
                     if show_plots: self.draw_grid_solution(continuous_soln, state, grid_origin, grid_obstacle_map, c, round(time, 2))
                     
+                    # for each robot in conflict, reroute its reference trajectory to match the grid solution
                     # for each robot in conflict, reroute its reference trajectory to match the grid solution
                     import copy
                     old_paths = copy.deepcopy(self.paths)
@@ -418,7 +446,6 @@ class MultiPathTrackerDB(MultiPathTracker):
 
                         # plan from the last point of the ref path to the robot's goal
                         # plan an RRT path from the current state to the goal
-
                         start = (new_ref[:, -1][0], new_ref[:, -1][1])
                         goal = (old_paths[i][:, -1][0], old_paths[i][:, -1][1])
 
@@ -441,11 +468,17 @@ class MultiPathTrackerDB(MultiPathTracker):
 
                             wp = [xs,ys]
 
-                            # Path from waypoint interpolation
-                            self.paths[i] = compute_path_from_wp(wp[0], wp[1], 0.05)
-                        else:
-                            print("RRT* failed to find a path")
-                            self.paths[i] = old_paths[i]
+                        # Path from waypoint interpolation
+                        path = compute_path_from_wp(wp[0], wp[1], 0.05)
+
+                        # combine the path with new_ref
+                        new_ref_x = np.concatenate((new_ref[0, :], path[0]))
+                        new_ref_y = np.concatenate((new_ref[1, :], path[1]))
+                        new_ref_theta = np.concatenate((new_ref[2, :], path[2]))
+
+                        self.paths[i] = np.array([new_ref_x, new_ref_y, new_ref_theta])
+
+                        self.visited_points_on_guide_paths[i] = []
 
                     for i in c:
                         ref, visited_guide_points = get_ref_trajectory(np.array(state[i]), 
@@ -456,21 +489,31 @@ class MultiPathTrackerDB(MultiPathTracker):
                                                             self.visited_points_on_guide_paths[i])
                 
                         self.visited_points_on_guide_paths[i] = visited_guide_points
-                        
-                        curr_state = np.array([0,0,0])
-                        
-                        x_mpc, u_mpc = self.mpc.step(
-                            curr_state,
-                            ref,
-                            self.control[i]
-                        )
-
-                        u_next[i] = [u_mpc[0, 0], u_mpc[1, 0]]
-
-                        x_mpc_global = self.ego_to_global_roomba(state[i], x_mpc)
-                        x_next[i] = x_mpc_global
-
                         self.trajs[i] = ref
+                    
+
+                    # use MPC to track the new reference trajectories
+                    # include all the robots that were in conflict in the MPC problem
+                    mpc = MultiMPC(len(c), # num robots
+                        self.dynamics, 
+                        self.T, 
+                        self.DT, 
+                        self.settings, 
+                        self.env.circle_obs, 
+                        self.env.rect_obs)
+
+                    curr_states = np.zeros((len(c), 3))
+                    these_trajs = [self.trajs[i] for i in c]
+                    these_controls = [self.control[i] for i in c]
+                    x_mpc, u_mpc = mpc.step(
+                        curr_states,
+                        these_trajs,
+                        these_controls
+                    )
+
+                    for i, r in enumerate(c):
+                        u_next[r] = [u_mpc[i*2, 0], u_mpc[i*2+1, 0]]
+                        x_next[r] = [x_mpc[i*3, 1], x_mpc[i*3+1, 1], x_mpc[i*3+2, 1]]
 
                 else:
                     if waiting: 
@@ -482,9 +525,16 @@ class MultiPathTrackerDB(MultiPathTracker):
 
                     else:
                         print("Using coupled solver to resolve conflict")
-                        # dynamycs w.r.t robot frame
+                        mpc = MultiMPC(self.num_robots, # num robots
+                           self.dynamics, 
+                           self.T, 
+                           self.DT, 
+                           self.settings, 
+                           self.env.circle_obs, 
+                           self.env.rect_obs)
+
                         curr_states = np.zeros((self.num_robots, 3))
-                        x_mpc, u_mpc = self.coupled_mpc.step(
+                        x_mpc, u_mpc = mpc.step(
                             curr_states,
                             self.trajs,
                             self.control
diff --git a/guided_mrmp/controllers/utils.py b/guided_mrmp/controllers/utils.py
index 29d541c..dfe4aad 100644
--- a/guided_mrmp/controllers/utils.py
+++ b/guided_mrmp/controllers/utils.py
@@ -58,12 +58,12 @@ def get_nn_idx(state, path, visited=[]):
     distances = np.linalg.norm(path[:2] - state[:2].reshape(2, 1), axis=0)
     
     # Set the distance to infinity for visited points
-    for point in visited:
-        point = np.array(point)
-        # print(f"point = {point}")
-        # print(f"path[:2] = {path[:2]}")
-        # Set the distance to infinity for visited points
-        distances = np.where(np.linalg.norm(path[:2] - point.reshape(2, 1), axis=0) < 1e-3, np.inf, distances)
+    # for point in visited:
+    #     point = np.array(point)
+    #     # print(f"point = {point}")
+    #     # print(f"path[:2] = {path[:2]}")
+    #     # Set the distance to infinity for visited points
+    #     distances = np.where(np.linalg.norm(path[:2] - point.reshape(2, 1), axis=0) < 1e-3, np.inf, distances)
     
     return np.argmin(distances)
 
@@ -101,10 +101,27 @@ def get_ref_trajectory(state, path, target_v, T, DT, path_visited_points=[]):
 
     xref = np.zeros((3, K))  # Reference trajectory for [x, y, theta]
 
-    # find the nearest path point to the current state
-    ind = get_nn_idx(state, path, path_visited_points)
+    path_distances = [0]
+    for i in range(1, len(path)):
+        dist = np.linalg.norm(np.array(path[i]) - np.array(path[i-1]))
+        path_distances.append(path_distances[-1] + dist)
+    
+    # Find the last visited point
+    last_visited_idx = 0 if path_visited_points == [] else path_visited_points[-1]
+    
+    # Find the spatially closest point after the last visited point
+    next_ind = last_visited_idx + 2
 
-    path_visited_points.append([path[0, ind], path[1, ind]])
+    ind = next_ind
+    # ind = get_nn_idx(state, path, path_visited_points)
+    # min_dist = float('inf')
+    # for i in range(last_visited_idx+2, len(path)):
+    #     dist = np.linalg.norm(np.array(path[i][:2]) - np.array(state[:2]))
+    #     if dist < min_dist:
+    #         min_dist = dist
+    #         ind = i 
+
+    path_visited_points.append(ind)
 
     # calculate the cumulative distance along the path
     cdist = np.append([0.0], np.cumsum(np.hypot(np.diff(path[0, :]), np.diff(path[1, :]))))
@@ -130,4 +147,55 @@ def get_ref_trajectory(state, path, target_v, T, DT, path_visited_points=[]):
     xref[2, :] = (xref[2, :] + np.pi) % (2.0 * np.pi) - np.pi
     xref[2, :] = fix_angle_reference(xref[2, :], xref[2, 0])
 
-    return xref, path_visited_points
\ No newline at end of file
+    return xref, path_visited_points
+
+
+# def get_ref_trajectory(state, path, target_v, T, DT, path_visited_points=[]):
+#     """
+#     Generates a reference trajectory for the Roomba.
+
+#     Args:
+#         state (array-like): Current state [x, y, theta]
+#         path (ndarray): Path points [x, y, theta] in the global frame
+#         path_visited_points (array-like): Visited path points [[x, y], [x, y], ...]
+#         target_v (float): Desired speed
+#         T (float): Control horizon duration
+#         DT (float): Control horizon time-step
+
+#     Returns:
+#         ndarray: Reference trajectory [x_k, y_k, theta_k] in the ego frame
+#     """
+#     K = int(T / DT)
+
+#     xref = np.zeros((3, K))  # Reference trajectory for [x, y, theta]
+
+#     # find the nearest path point to the current state
+#     ind = get_nn_idx(state, path, path_visited_points)
+
+#     path_visited_points.append([path[0, ind], path[1, ind]])
+
+#     # calculate the cumulative distance along the path
+#     cdist = np.append([0.0], np.cumsum(np.hypot(np.diff(path[0, :]), np.diff(path[1, :]))))
+#     cdist = np.clip(cdist, cdist[0], cdist[-1])
+
+#     # determine where we want the robot to be at each time step 
+#     start_dist = cdist[ind]
+#     interp_points = [d * DT * target_v + start_dist for d in range(1, K + 1)]
+
+#     # interpolate between these points to get the reference trajectory
+#     xref[0, :] = np.interp(interp_points, cdist, path[0, :])
+#     xref[1, :] = np.interp(interp_points, cdist, path[1, :])
+#     xref[2, :] = np.interp(interp_points, cdist, path[2, :])
+    
+#     # Transform to ego frame
+#     dx = xref[0, :] - state[0]
+#     dy = xref[1, :] - state[1]
+#     xref[0, :] = dx * np.cos(-state[2]) - dy * np.sin(-state[2])  # X
+#     xref[1, :] = dy * np.cos(-state[2]) + dx * np.sin(-state[2])  # Y
+#     xref[2, :] = path[2, ind] - state[2]  # Theta
+
+#     # Normalize the angles
+#     xref[2, :] = (xref[2, :] + np.pi) % (2.0 * np.pi) - np.pi
+#     xref[2, :] = fix_angle_reference(xref[2, :], xref[2, 0])
+
+#     return xref, path_visited_points
-- 
GitLab