Skip to content
Snippets Groups Projects
optimizer.py 4.26 KiB
class Optimizer:
    """Solve a pre-built optimization problem with IPOPT and collect the results.

    ``problem`` is a dict that must contain an ``'opti'`` entry. That entry is
    used through ``set_initial``, ``solver``, ``solve`` and ``lam_g``, so it is
    presumably a ``casadi.Opti`` instance (TODO confirm). Every other entry is
    treated as something to read back from the solution.
    """

    def __init__(self, problem):
        self.problem = problem

    def solve_optimization_problem(self, initial_guesses=None, solver_options=None, lam_g=None):
        """Run IPOPT on ``self.problem['opti']`` and return a results dict.

        Args:
            initial_guesses: Optional mapping from a key of ``self.problem`` to
                the initial value to set for that entry.
            solver_options: Optional options dict forwarded to
                ``opti.solver('ipopt', solver_options)``. An empty or ``None``
                value calls ``opti.solver('ipopt')`` with no options.
            lam_g: Optional initial guess for ``opti.lam_g``, for example the
                ``'lam_g'`` entry of a previous result (warm start).

        Returns:
            A dict with:

            * ``'status'``: ``'succeeded'`` if ``opti.solve()`` returned,
              ``'failed'`` if it raised.
            * ``'solution'``: the object returned by ``opti.solve()``, or
              ``None`` on failure.
            * on success, one entry per non-``'opti'`` key of ``self.problem``:
              ``sol.value(entry)``, or the raw entry if it cannot be evaluated.
            * ``'lam_g'``: ``sol.value(opti.lam_g)`` on success, ``None`` on
              failure. ``None`` can be passed back in as ``lam_g`` safely.
        """
        opti = self.problem['opti']

        if initial_guesses:
            for param, value in initial_guesses.items():
                opti.set_initial(self.problem[param], value)

        if lam_g is not None:
            opti.set_initial(opti.lam_g, lam_g)

        # Set numerical backend, with options if provided.
        if solver_options:
            opti.solver('ipopt', solver_options)
        else:
            opti.solver('ipopt')

        try:
            sol = opti.solve()
        except Exception:
            # A failed solve is reported through 'status' instead of raising.
            # Catching Exception (not a bare except) lets KeyboardInterrupt and
            # SystemExit still propagate.
            sol = None
            status = 'failed'
        else:
            status = 'succeeded'

        results = {
            'status': status,
            'solution': sol,
        }

        if sol is None:
            # There is no solution to evaluate multipliers on. Report None so
            # the caller can still feed this result back as `lam_g`.
            results['lam_g'] = None
            return results

        for var_name, var in self.problem.items():
            if var_name == 'opti':
                continue
            try:
                results[var_name] = sol.value(var)
            except Exception:
                # Entries that sol.value() cannot evaluate are stored unchanged.
                results[var_name] = var

        results['lam_g'] = sol.value(opti.lam_g)
        return results