From 202c2569ebf810db9e39c1cd6822c57b034b3ecf Mon Sep 17 00:00:00 2001
From: rachelmoan <moanrachel516@gmail.com>
Date: Wed, 13 Nov 2024 12:50:20 -0600
Subject: [PATCH] try setting dual variables

---
 .../conflict_resolvers/traj_opt_resolver.py       | 15 +++++++--------
 guided_mrmp/optimizer.py                          | 13 +++++++++++--
 guided_mrmp/tests/test_traj_opt.py                |  5 +++--
 3 files changed, 21 insertions(+), 12 deletions(-)

diff --git a/guided_mrmp/conflict_resolvers/traj_opt_resolver.py b/guided_mrmp/conflict_resolvers/traj_opt_resolver.py
index 10dfbb4..0160d00 100644
--- a/guided_mrmp/conflict_resolvers/traj_opt_resolver.py
+++ b/guided_mrmp/conflict_resolvers/traj_opt_resolver.py
@@ -169,12 +169,12 @@ class TrajOptResolver():
 
         return {'opti':opti, 'X':X, 'U':U, 'T':T, 'cost':cost, 'robot_cost':robot_cost, 'obs_cost':obs_cost, 'time_cost':time_cost, 'control_cost':control_cost, 'goal_cost':goal_cost}
 
-    def solve_optimization_problem(self, problem, initial_guesses=None, solver_options=None):
+    def solve_optimization_problem(self, problem, initial_guesses=None, solver_options=None, prior_solution=None):
         opt = Optimizer(problem)
-        results = opt.solve_optimization_problem(initial_guesses, solver_options)
-        return results
+        results,sol = opt.solve_optimization_problem(initial_guesses, solver_options, prior_solution)
+        return results,sol
     
-    def solve(self, N, x_range, y_range, initial_guesses=None, solver_options=None):
+    def solve(self, N, x_range, y_range, initial_guesses=None, solver_options=None, prior_solution=None):
         """
         Setup and solve a multi-robot traj opt problem
 
@@ -186,9 +186,7 @@ class TrajOptResolver():
         problem = self.problem_setup(N, x_range, y_range)
 
 
-        
-
-        results = self.solve_optimization_problem(problem, initial_guesses, solver_options)
+        results,old_sol = self.solve_optimization_problem(problem, initial_guesses, solver_options, prior_solution)
 
         if results['status'] == 'failed':
             return None, None, None, None, None, None, None, None
@@ -197,6 +195,7 @@ class TrajOptResolver():
         sol = results['solution']
         U = results['U']
         T = results['T']
+        lam_g = results['lam_g']
 
         # Extract the values that we want from the optimizer's solution
         pos = X[:self.num_robots*2,:]               
@@ -207,7 +206,7 @@ class TrajOptResolver():
         vels = U[0::2,:]
         omegas = U[1::2,:]
 
-        return sol,pos, vels, omegas, x_vals, y_vals, theta_vals, T
+        return lam_g,sol,pos, vels, omegas, x_vals, y_vals, theta_vals, T
 
     def get_local_controls(self, controls):
         """ 
diff --git a/guided_mrmp/optimizer.py b/guided_mrmp/optimizer.py
index f71163e..e82cc8a 100644
--- a/guided_mrmp/optimizer.py
+++ b/guided_mrmp/optimizer.py
@@ -2,7 +2,7 @@ class Optimizer:
     def __init__(self, problem):
         self.problem = problem
 
-    def solve_optimization_problem(self, initial_guesses=None, solver_options=None):
+    def solve_optimization_problem(self, initial_guesses=None, solver_options=None, lam_g=None):
         opti = self.problem['opti']
 
         X = self.problem['X']
@@ -13,6 +13,9 @@ class Optimizer:
             for param, value in initial_guesses.items():
                 opti.set_initial(self.problem[param], value)
 
+        if lam_g is not None:
+            opti.set_initial(opti.lam_g, lam_g)
+
         # Set numerical backend, with options if provided
         if solver_options:
             opti.solver('ipopt', solver_options)
@@ -48,6 +51,7 @@ class Optimizer:
 
             plt.show()
             
+        
     
         # opti.callback(print_intermediates_callback)
 
@@ -62,6 +66,7 @@ class Optimizer:
             'status' : status,
             'solution' : sol,
         }
+        
 
         # print(f"Final total = {sol.value(self.problem['cost'])}")
         # print(f"robot costs = {sol.value(self.problem['robot_cost'])}")
@@ -75,4 +80,8 @@ class Optimizer:
                 if var_name != 'opti':
                     results[var_name] = sol.value(var)
 
-        return results
+        opti = self.problem['opti']
+        lam_g = sol.value(opti.lam_g)
+        results['lam_g'] = lam_g
+
+        return results,sol
diff --git a/guided_mrmp/tests/test_traj_opt.py b/guided_mrmp/tests/test_traj_opt.py
index e9d1b3f..c043772 100644
--- a/guided_mrmp/tests/test_traj_opt.py
+++ b/guided_mrmp/tests/test_traj_opt.py
@@ -352,11 +352,12 @@ if __name__ == "__main__":
     
     import time
     start = time.time()
-    sol,pos, vels, omegas, xs, ys, thetas, T = solver.solve(N, x_range, y_range, initial_guesses, solver_options)
+    old_sol, sol,pos, vels, omegas, xs, ys, thetas, T = solver.solve(N, x_range, y_range, initial_guesses, solver_options)
     end = time.time()
 
 
 
+
     # plot_paths_db(circle_obs, num_robots, robot_starts, robot_goals, pos, None, x_range, y_range, rob_radius, "Optimizer solution")
     # plot_sim(xs, ys, thetas, x_range, y_range, rob_radius, "Optimizer solution")
 
@@ -411,7 +412,7 @@ if __name__ == "__main__":
 
         import time
         start = time.time()
-        sol,pos, vels, omegas, xs, ys, thetas, T = solver.solve(N, x_range, y_range, initial_guesses, solver_options)
+        lam_g,sol,pos, vels, omegas, xs, ys, thetas, T = solver.solve(N, x_range, y_range, initial_guesses, solver_options, old_sol)
 
         end = time.time()
 
-- 
GitLab