import numpy as np
import casadi as ca
from guided_mrmp.optimizer import Optimizer

# Import-time side effect: tells NumPy to ignore divide-by-zero and
# invalid-value (NaN-producing) floating-point errors instead of warning.
# This affects NumPy code outside this module too, not just this file.
# NOTE(review): confirm this suppression is still needed; it can hide NaNs.
np.seterr(divide="ignore", invalid="ignore")

class MultiMPC:
    """Joint MPC controller that plans all robots in a single optimisation problem.

    Layout of the decision variables: robot ``i`` owns rows
    ``i*nx : (i+1)*nx`` of the state matrix ``X`` (one column per time step,
    ``control_horizon + 1`` columns) and rows ``i*nu : (i+1)*nu`` of the
    control matrix ``U`` (``control_horizon`` columns). The first two entries of
    each robot's state block are its x/y position; these are what the collision
    costs use.
    """

    # Added under square roots so distances stay differentiable when two points
    # coincide (e.g. every robot sitting at the origin in an all-zero initial guess).
    _DIST_EPS = 1e-9

    def __init__(self, num_robots, model, T, DT, state_cost, final_state_cost, input_cost, input_rate_cost, settings, circle_obs):
        """
        Initializes the MPC controller.

        params:
            num_robots (int):        number of robots planned jointly
            model:                   robot model; this class uses state_dimension(),
                                     control_dimension(), radius, linearize(state, cmd, dt),
                                     max_acc, max_d_acc and max_d_steer
            T (float):               prediction horizon length
            DT (float):              discretisation time step
            state_cost:              diagonal of the running state weight Q (length nx)
            final_state_cost:        diagonal of the terminal state weight Qf (length nx)
            input_cost:              diagonal of the control effort weight R (length nu)
            input_rate_cost:         diagonal of the control-change weight P (length nu)
            settings (dict):         solver settings with keys 'acceptable_tol',
                                     'acceptable_iter', 'print_level', 'print_time'
            circle_obs:              iterable of circular obstacles; obs[0], obs[1] are
                                     used as the obstacle centre
        """
        self.nx = model.state_dimension()  # number of state vars per robot
        self.nu = model.control_dimension()  # number of input/control vars per robot
        self.num_robots = num_robots
        self.robot_radius = model.radius

        self.robot_model = model
        self.dt = DT

        self.circle_obs = circle_obs

        # Number of control intervals that fit in the horizon. The small tolerance
        # stops float round-off (0.3 / 0.1 == 2.9999999999999996) from dropping an
        # interval; non-multiples are still floored, as before.
        self.control_horizon = int(np.floor(T / DT + 1e-9))

        # Weight for the error in state
        self.Q = np.diag(state_cost)

        # Weight for the error in final state
        self.Qf = np.diag(final_state_cost)

        # Weights for control effort and for changes in control between steps
        self.R = np.diag(input_cost)
        self.P = np.diag(input_rate_cost)

        self.acceptable_tol = settings['acceptable_tol']
        self.acceptable_iter = settings['acceptable_iter']
        self.print_level = settings['print_level']
        self.print_time = settings['print_time']

    def apply_quadratic_barrier(self, d_max, d, c):
        """
        Applies a quadratic barrier to some given distance. The quadratic barrier
        is a soft barrier function. We are using it for now to avoid any issues with
        invalid initial solutions, which hard barrier functions cannot handle.

        The cost is ``c * max(0, d_max - d)**2``: zero once ``d >= d_max`` and
        growing quadratically as ``d`` drops below ``d_max``.

        params:
            d (float):      distance to the obstacle
            c (float):      controls the steepness of curve.
                            higher c --> gets more expensive faster as you move toward obs
            d_max (float):  The threshold distance at which the barrier starts to apply
        """
        # Clamp the difference, not its square: a square is never negative, so
        # fmax(0, (d_max - d)**2) would also penalise being farther than d_max.
        return c * ca.fmax(0, d_max - d) ** 2

    # ------------------------------------------------------------------ helpers

    def _state_slice(self, i):
        """Row slice of the stacked state matrix that holds robot ``i``'s state."""
        return slice(i * self.nx, (i + 1) * self.nx)

    def _control_slice(self, i):
        """Row slice of the stacked control matrix that holds robot ``i``'s inputs."""
        return slice(i * self.nu, (i + 1) * self.nu)

    def _position(self, X, i, k):
        """(x, y) of robot ``i`` at time step ``k``: the first two rows of its block."""
        row = i * self.nx
        return X[row, k], X[row + 1, k]

    def _distance(self, dx, dy):
        """Euclidean norm of (dx, dy), smoothed so its gradient exists at zero."""
        return ca.sqrt(dx ** 2 + dy ** 2 + self._DIST_EPS)

    def _initial_state_matrix(self, initial_state):
        """Return the initial states as a float array of shape (num_robots, nx).

        Accepts one row per robot or a flat vector of per-robot states stacked
        back to back.
        """
        return np.asarray(initial_state, dtype=float).reshape(self.num_robots, self.nx)

    def _tracking_cost(self, X, U, target):
        """Running/terminal tracking cost plus control effort and control-rate cost.

        ``target[i][j][k]`` is read as state component ``j`` of robot ``i``'s
        reference at step ``k``; it is compared against predicted state ``k + 1``.
        """
        cost = 0
        for i in range(self.num_robots):
            xs = self._state_slice(i)
            us = self._control_slice(i)
            for k in range(self.control_horizon):
                # State k+1 is the first state that control k can influence.
                err = X[xs, k + 1] - [target[i][j][k] for j in range(self.nx)]
                cost += ca.mtimes([err.T, self.Q, err])

                # control effort
                cost += ca.mtimes([U[us, k].T, self.R, U[us, k]])

                # Penalize large changes in control (now applied to every robot)
                if k > 0:
                    du = U[us, k] - U[us, k - 1]
                    cost += ca.mtimes([du.T, self.P, du])

            # Final state cost
            final_err = X[xs, -1] - [target[i][j][-1] for j in range(self.nx)]
            cost += ca.mtimes([final_err.T, self.Qf, final_err])
        return cost

    def _robot_collision_cost(self, X):
        """Soft barrier cost on the distance between every pair of robots."""
        cost = 0
        # States 1..H: state 0 is fixed by the initial-state constraint, and the
        # terminal state must be protected too.
        for k in range(1, self.control_horizon + 1):
            for r1 in range(self.num_robots):
                x1, y1 = self._position(X, r1, k)
                for r2 in range(r1 + 1, self.num_robots):
                    x2, y2 = self._position(X, r2, k)
                    d = self._distance(x1 - x2, y1 - y2)
                    cost += self.apply_quadratic_barrier(6 * self.robot_radius, d - self.robot_radius * 2, 1)
        return cost

    def _obstacle_cost(self, X):
        """Soft barrier cost on the distance from each robot to each circular obstacle."""
        cost = 0
        for k in range(1, self.control_horizon + 1):
            for i in range(self.num_robots):
                px, py = self._position(X, i, k)
                for obs in self.circle_obs:
                    d = self._distance(px - obs[0], py - obs[1])
                    # NOTE(review): clearance subtracts 2*robot_radius; any obstacle
                    # radius stored in obs is not used here — confirm intent.
                    cost += self.apply_quadratic_barrier(6 * self.robot_radius, d - self.robot_radius * 2, 1)
        return cost

    def _add_constraints(self, opti, X, U, x0, prev_cmd, As, Bs, Cs):
        """Add initial-state, linearised-dynamics, input and input-rate constraints."""
        model = self.robot_model
        acc_step = model.max_d_acc * self.dt      # max |change| of input 0 per step
        steer_step = model.max_d_steer * self.dt  # max |change| of input 1 per step

        for i in range(self.num_robots):
            xs = self._state_slice(i)
            us = self._control_slice(i)

            # Pin the first column to this robot's full measured state vector.
            opti.subject_to(X[xs, 0] == ca.DM(x0[i]))

            # Linearised dynamics x_{k+1} = A x_k + B u_k + C (matrices fixed over the horizon).
            A = ca.MX(As[i])
            B = ca.MX(Bs[i])
            C = ca.MX(Cs[i])
            for k in range(self.control_horizon):
                opti.subject_to(X[xs, k + 1] == ca.mtimes(A, X[xs, k]) + ca.mtimes(B, U[us, k]) + C)

            # NOTE(review): both input rows are bounded by max_acc; confirm the
            # second (steering-like) input has no separate limit on the model.
            opti.subject_to(opti.bounded(-model.max_acc, U[us, :], model.max_acc))

            # Rate limits |u_k - u_{k-1}| / dt <= max, written as two-sided bounds
            # (equivalent for dt > 0) rather than fabs() so they stay smooth for IPOPT.
            a_row = i * self.nu
            s_row = a_row + 1
            opti.subject_to(opti.bounded(-acc_step, U[a_row, 0] - prev_cmd[i][0], acc_step))
            opti.subject_to(opti.bounded(-steer_step, U[s_row, 0] - prev_cmd[i][1], steer_step))
            for k in range(1, self.control_horizon):
                opti.subject_to(opti.bounded(-acc_step, U[a_row, k] - U[a_row, k - 1], acc_step))
                opti.subject_to(opti.bounded(-steer_step, U[s_row, k] - U[s_row, k - 1], steer_step))

    # --------------------------------------------------------------- public API

    def setup_mpc_problem(self, initial_state, target, prev_cmd, As, Bs, Cs):
        """
        Create the cost function and constraints for the optimization problem.

        inputs:
            - initial_state (num_robots x nx array, or flat stacked vector): Initial state for each robot
            - target : Target state for each robot; target[i][j][k] is component j of robot i at step k
            - prev_cmd: Previous control input for each robot
            - As: List of A matrices for each robot
            - Bs: List of B matrices for each robot
            - Cs: List of C matrices for each robot

        returns:
            dict with the Opti instance, decision variables X and U, the inputs,
            the tracking cost expression and the robot-robot collision cost expression.
        """
        opti = ca.Opti()

        # Decision variables: per-robot state/control blocks stacked row-wise.
        X = opti.variable(self.nx * self.num_robots, self.control_horizon + 1)
        U = opti.variable(self.nu * self.num_robots, self.control_horizon)

        x0 = self._initial_state_matrix(initial_state)

        cost = self._tracking_cost(X, U, target)
        dist_to_other_robots = self._robot_collision_cost(X)
        obstacle_cost = self._obstacle_cost(X)
        opti.minimize(cost + .5 * dist_to_other_robots + .5 * obstacle_cost)

        self._add_constraints(opti, X, U, x0, prev_cmd, As, Bs, Cs)

        return {
            'opti': opti,
            'X': X,
            'U': U,
            'initial_state': ca.MX(initial_state),
            'target': target,
            'prev_cmd': prev_cmd,
            'cost': cost,
            'dist_to_other_robots': dist_to_other_robots
        }

    def solve_optimization_problem(self, problem, initial_guesses=None, solver_options=None):
        """Solve ``problem`` (as returned by setup_mpc_problem) via the project Optimizer."""
        opt = Optimizer(problem)
        results = opt.solve_optimization_problem(initial_guesses, solver_options)
        return results

    def step(self, initial_state, target, prev_cmd, initial_guesses=None):
        """
        Sets up and solves the optimization problem.

        Each robot's dynamics are linearised once, around its current state and
        previous command, and that linearisation is used over the whole horizon.

        Args:
            initial_state: List of current estimates of [x, y, heading] for each robot
            target: State space reference, in the same frame as the provided current state
            prev_cmd: List of previous commands [v, delta] for all robots
            initial_guesses: Optional initial guess for the optimizer

        Returns:
            x_opt: Optimal state trajectory (None if the solve did not succeed)
            u_opt: Optimal control trajectory (None if the solve did not succeed)
        """
        As, Bs, Cs = [], [], []
        for i in range(self.num_robots):
            A, B, C = self.robot_model.linearize(initial_state[i], prev_cmd[i], self.dt)
            As.append(A)
            Bs.append(B)
            Cs.append(C)

        solver_options = {'ipopt.print_level': self.print_level,
                          'print_time': self.print_time,
                          'ipopt.acceptable_tol': self.acceptable_tol,
                          'ipopt.acceptable_iter': self.acceptable_iter}

        problem = self.setup_mpc_problem(initial_state, target, prev_cmd, As, Bs, Cs)

        result = self.solve_optimization_problem(problem, initial_guesses, solver_options)

        if result['status'] == 'succeeded':
            x_opt = result['X']
            u_opt = result['U']
        else:
            print("Optimization failed")
            x_opt = None
            u_opt = None

        return x_opt, u_opt