import numpy as np
# from guided_mrmp.utils import Roomba
from guided_mrmp.optimizers.optimizer import Optimizer, iLQR

# Tell NumPy to ignore (rather than warn about) floating-point division by
# zero and invalid operations. NOTE: this runs as a side effect of importing
# this module, so it also affects NumPy code outside this file.
np.seterr(divide="ignore", invalid="ignore")

import cvxpy as opt


class MPC:
    """Model-predictive controller built on a linearized robot model.

    Each call to :meth:`step` linearizes the model around the current state
    and the previous command, expresses a quadratic tracking cost together
    with kinematic and actuation constraints as cvxpy expressions, and hands
    the resulting problem to the configured optimizer.
    """

    def __init__(self, model, T, DT, state_cost, final_state_cost, input_cost, input_rate_cost):
        """
        Configure the controller.

        Args:
            model: Robot model. Must provide ``state_dimension()``,
                ``control_dimension()`` and ``linearize(state, cmd, dt)``, plus
                the bound attributes ``max_acc``, ``max_steer``, ``max_d_acc``
                and ``max_d_steer`` used by :meth:`constraints`.
            T: Length of the prediction horizon, in the same time unit as DT.
            DT: Duration of one control interval.
            state_cost: Diagonal entries of the state-error weight matrix Q.
            final_state_cost: Diagonal entries of the final-state-error
                weight matrix Qf.
            input_cost: Diagonal entries of the control-magnitude weight
                matrix R.
            input_rate_cost: Diagonal entries of the control-rate weight
                matrix P.

        Raises:
            ValueError: If T and DT yield fewer than one control interval.
        """
        self.nx = model.state_dimension()  # number of state vars
        self.nu = model.control_dimension()  # number of input/control vars

        self.robot_model = model
        self.dt = DT

        # How far we can look into the future divided by our dt is the number
        # of control intervals. The small tolerance keeps float round-off from
        # dropping an interval: 0.3 / 0.1 evaluates to 2.9999999999999996,
        # which plain int() would truncate to 2 instead of 3.
        self.control_horizon = int(T / DT + 1e-9)
        if self.control_horizon < 1:
            raise ValueError(
                f"T={T} and DT={DT} give {self.control_horizon} control intervals; "
                "at least one is required"
            )

        # Weight for the error in state
        self.Q = np.diag(state_cost)

        # Weight for the error in final state
        self.Qf = np.diag(final_state_cost)

        # Weights for control magnitude (R) and control rate of change (P)
        self.R = np.diag(input_cost)
        self.P = np.diag(input_rate_cost)

        # Instantiate the optimizer
        self.optimizer = iLQR(self.nx, self.nu, self.control_horizon)

    def cost_function(self, x, u, target):
        """
        Build the quadratic cost expression for the optimizer.

        Parameters:
            - x: State variables, indexed as x[:, k] for step k
              (control_horizon + 1 columns).
            - u: Control inputs, indexed as u[:, k] for step k
              (control_horizon columns).
            - target: Reference states; column k is compared against
              x[:, k + 1], and the last column against the final state.
        Returns:
            - cost: The summed cost expression (tracking + final state +
              control magnitude + control rate of change).
        """
        cost = 0

        # Tracking error cost: the state reached after applying u[:, k] is
        # compared with the k-th reference column.
        for k in range(self.control_horizon):
            cost += opt.quad_form(x[:, k + 1] - target[:, k], self.Q)

        # Final point tracking cost
        cost += opt.quad_form(x[:, -1] - target[:, -1], self.Qf)

        # Actuation magnitude cost
        for k in range(self.control_horizon):
            cost += opt.quad_form(u[:, k], self.R)

        # Actuation rate of change cost
        for k in range(1, self.control_horizon):
            cost += opt.quad_form(u[:, k] - u[:, k - 1], self.P)

        return cost

    def constraints(self, initial_state, A, B, C, x, u, robot_model, dt, prev_cmd):
        """
        Build the list of constraints for the optimization problem.

        Only the first two input channels (rows 0 and 1 of ``u``) are bounded:
        row 0 by ``max_acc`` / ``max_d_acc`` and row 1 by ``max_steer`` /
        ``max_d_steer``.

        Args:
            initial_state: State that x[:, 0] is pinned to.
            A, B, C: Linearized dynamics, used as
                x[:, k + 1] == A @ x[:, k] + B @ u[:, k] + C.
            x: State variables, one column per step.
            u: Control variables, one row per input and one column per step.
            robot_model: Source of the actuation and rate limits.
            dt: Control interval, used to convert input differences to rates.
            prev_cmd: Previously applied command; limits the rate of change
                of the first control column.

        Returns:
            list: The constraint expressions.
        """
        constr = []

        # Kinematics constraints
        for k in range(self.control_horizon):
            constr += [x[:, k + 1] == A @ x[:, k] + B @ u[:, k] + C]

        # Initial state
        constr += [x[:, 0] == initial_state]

        # Actuation bounds. Rows of u select the input channel and columns the
        # time step, so each bound must cover a whole ROW to hold over the
        # entire horizon (indexing a column would only bound a single step).
        constr += [opt.abs(u[0, :]) <= robot_model.max_acc]
        constr += [opt.abs(u[1, :]) <= robot_model.max_steer]

        # Actuation rate of change bounds; the first step is measured against
        # the previously applied command.
        constr += [opt.abs(u[0, 0] - prev_cmd[0]) / dt <= robot_model.max_d_acc]
        constr += [opt.abs(u[1, 0] - prev_cmd[1]) / dt <= robot_model.max_d_steer]
        for k in range(1, self.control_horizon):
            constr += [opt.abs(u[0, k] - u[0, k - 1]) / dt <= robot_model.max_d_acc]
            constr += [opt.abs(u[1, k] - u[1, k - 1]) / dt <= robot_model.max_d_steer]

        return constr

    def step(self, initial_state, target, prev_cmd, initial_guess=None):
        """
        Sets up and solves the optimization problem.

        The model is linearized around ``initial_state`` and ``prev_cmd`` and
        the resulting A, B, C matrices are used for every step of the horizon.

        Args:
            initial_state: Current state estimate
            target: State space reference, in the same frame as the provided current state
            prev_cmd: Previously applied control input, ordered like the rows of u
            initial_guess: Optional initial guess forwarded to the optimizer

        Returns:
            x_opt: Optimal state trajectory, as returned by the optimizer
            u_opt: Optimal control trajectory, as returned by the optimizer
        """
        A, B, C = self.robot_model.linearize(initial_state, prev_cmd, self.dt)

        # Set up variables for the optimization problem
        x = opt.Variable((self.nx, self.control_horizon + 1))
        u = opt.Variable((self.nu, self.control_horizon))

        # Define the cost function
        cost = self.cost_function(x, u, target)

        # Define the constraints
        constraints = self.constraints(initial_state, A, B, C, x, u, self.robot_model, self.dt, prev_cmd)

        # Use the optimizer to solve the optimization problem
        x_opt, u_opt = self.optimizer.solve(x, u, cost, constraints, initial_guess)

        return x_opt, u_opt