import numpy as np
# from guided_mrmp.utils import Roomba
from guided_mrmp.optimizer import Optimizer

# Import-time side effect: tell NumPy to silently ignore divide-by-zero and
# invalid-operation (e.g. 0/0 -> nan) floating-point errors instead of warning.
# This is not scoped to this module; it changes NumPy's error-handling state for
# all later NumPy code in the importing thread.
# NOTE(review): consider a scoped `with np.errstate(...)` where it is needed.
np.seterr(divide="ignore", invalid="ignore")

import cvxpy as opt


class MPC:
    """Model predictive controller built on a linearized robot model.

    The nonlinear dynamics are linearized about an operating point and
    discretized with forward Euler. The resulting ``(A, B, C)`` matrices are
    handed to :class:`Optimizer`, which solves the actual optimization.
    """

    def __init__(self, model, T, DT, state_cost, final_state_cost, input_cost, input_rate_cost):
        """Configure the controller and build the underlying optimizer.

        Args:
            model: Robot model. It must provide ``state_dimension()`` and
                ``control_dimension()``. ``get_linear_model_matrices`` also
                reads ``model.wheelbase``.
            T (float): How far ahead to look; same time unit as ``DT``.
            DT (float): Discretization time step.
            state_cost (sequence of float): Diagonal of the state-error weight
                matrix ``Q``. Its length must equal the state dimension.
            final_state_cost (sequence of float): Diagonal of the final-state
                error weight matrix ``Qf``. Its length must equal the state
                dimension.
            input_cost (sequence of float): Diagonal of the control-effort
                weight matrix ``R``. Its length must equal the control dimension.
            input_rate_cost (sequence of float): Diagonal of the weight matrix
                ``P`` on control differences. Its length must equal the control
                dimension.

        Raises:
            ValueError: If any cost vector has the wrong length.
        """
        self.nx = model.state_dimension()  # number of state vars
        self.nu = model.control_dimension()  # number of input/control vars

        if len(state_cost) != self.nx:
            raise ValueError(f"State Error cost matrix should be of size {self.nx}")
        if len(final_state_cost) != self.nx:
            raise ValueError(f"End State Error cost matrix should be of size {self.nx}")
        if len(input_cost) != self.nu:
            raise ValueError(f"Control Effort cost matrix should be of size {self.nu}")
        if len(input_rate_cost) != self.nu:
            raise ValueError(
                f"Control Effort Difference cost matrix should be of size {self.nu}"
            )

        self.robot_model = model
        self.dt = DT

        # Number of control intervals = look-ahead time / time step.
        # Plain int() truncation turns float artifacts such as
        # 0.3 / 0.1 == 2.9999999999999996 into 2. Snap to the nearest integer
        # when the ratio is numerically whole; otherwise keep the floor.
        steps = T / DT
        nearest = int(round(steps))
        if np.isclose(steps, nearest, rtol=1e-9, atol=1e-12):
            self.control_horizon = nearest
        else:
            self.control_horizon = int(steps)

        # Weight for the error in state
        self.Q = np.diag(state_cost)

        # Weight for the error in final state
        self.Qf = np.diag(final_state_cost)

        # Weights for control effort and for control differences
        self.R = np.diag(input_cost)
        self.P = np.diag(input_rate_cost)

        # Instantiate the optimizer
        self.optimizer = Optimizer(self.nx, self.nu, self.control_horizon, self.Q, self.Qf, self.R, self.P)

    def _discretize(self, A, B, f_xu, x_bar, u_bar):
        """Forward-Euler discretize continuous Jacobians about ``(x_bar, u_bar)``.

        Args:
            A (np.ndarray): (nx, nx) Jacobian of the dynamics w.r.t. the state.
            B (np.ndarray): (nx, nu) Jacobian of the dynamics w.r.t. the input.
            f_xu (np.ndarray): (nx,) dynamics evaluated at the operating point.
            x_bar (np.ndarray): (nx,) operating-point state.
            u_bar (np.ndarray): (nu,) operating-point input.

        Returns:
            tuple: ``(A_lin, B_lin, C_lin)`` such that
            ``x[k+1] ~= A_lin @ x[k] + B_lin @ u[k] + C_lin``. ``C_lin`` is the
            affine offset that makes this model reproduce the Euler step
            ``x_bar + dt * f(x_bar, u_bar)`` exactly at the operating point.
        """
        A_lin = np.eye(self.nx) + self.dt * A
        B_lin = self.dt * B
        C_lin = self.dt * (f_xu - A @ x_bar - B @ u_bar)
        return A_lin, B_lin, C_lin

    def get_linear_model_matrices_roomba(self, x_bar, u_bar):
        """Linearize the unicycle (Roomba) model ``x' = Ax + Bu + C``.

        Dynamics: ``x' = v cos(theta)``, ``y' = v sin(theta)``,
        ``theta' = omega``.

        Args:
            x_bar (array-like): Operating-point state ``[x, y, theta]``. Any
                shape with ``nx`` elements is accepted (e.g. a list or column
                vector).
            u_bar (array-like): Operating-point input ``[v, omega]``. Any
                shape with ``nu`` elements is accepted.

        Returns:
            tuple: ``(A_lin, B_lin, C_lin)``, the discrete-time matrices with
            shapes (nx, nx), (nx, nu) and (nx,).
        """
        x_bar = np.asarray(x_bar, dtype=float).reshape(self.nx)
        u_bar = np.asarray(u_bar, dtype=float).reshape(self.nu)

        theta = x_bar[2]
        v, omega = u_bar[0], u_bar[1]
        ct, st = np.cos(theta), np.sin(theta)

        # State Jacobian: only theta enters the dynamics nonlinearly.
        A = np.zeros((self.nx, self.nx))
        A[0, 2] = -v * st
        A[1, 2] = v * ct

        # Input Jacobian.
        B = np.zeros((self.nx, self.nu))
        B[0, 0] = ct
        B[1, 0] = st
        B[2, 1] = 1.0

        f_xu = np.array([v * ct, v * st, omega])
        return self._discretize(A, B, f_xu, x_bar, u_bar)

    def get_linear_model_matrices(self, x_bar, u_bar):
        """Linearize the kinematic bicycle model ``x' = Ax + Bu + C``.

        Dynamics: ``x' = v cos(theta)``, ``y' = v sin(theta)``, ``v' = a``,
        ``theta' = v tan(delta) / L``, where ``L`` is
        ``self.robot_model.wheelbase``.

        Args:
            x_bar (array-like): Operating-point state ``[x, y, v, theta]``. Any
                shape with ``nx`` elements is accepted.
            u_bar (array-like): Operating-point input ``[a, delta]``. Any
                shape with ``nu`` elements is accepted.

        Returns:
            tuple: ``(A_lin, B_lin, C_lin)``, the discrete-time matrices with
            shapes (nx, nx), (nx, nu) and (nx,).
        """
        x_bar = np.asarray(x_bar, dtype=float).reshape(self.nx)
        u_bar = np.asarray(u_bar, dtype=float).reshape(self.nu)

        v, theta = x_bar[2], x_bar[3]
        a, delta = u_bar[0], u_bar[1]
        wheelbase = self.robot_model.wheelbase

        ct, st = np.cos(theta), np.sin(theta)
        cd, td = np.cos(delta), np.tan(delta)

        A = np.zeros((self.nx, self.nx))
        A[0, 2] = ct
        A[0, 3] = -v * st
        A[1, 2] = st
        A[1, 3] = v * ct
        # d/dv [v * tan(delta) / L] = tan(delta) / L; no factor of v here.
        A[3, 2] = td / wheelbase

        B = np.zeros((self.nx, self.nu))
        B[2, 0] = 1.0
        # d/ddelta [v * tan(delta) / L] = v / (L * cos(delta)^2)
        B[3, 1] = v / (wheelbase * cd**2)

        f_xu = np.array([v * ct, v * st, a, v * td / wheelbase])
        return self._discretize(A, B, f_xu, x_bar, u_bar)

    def step(self, initial_state, target, prev_cmd):
        """Run one MPC iteration.

        The dynamics are always linearized with the Roomba (unicycle) model
        about ``(initial_state, prev_cmd)``, whatever model was passed to
        ``__init__``. The robot model and ``dt`` are forwarded to the optimizer.

        Args:
            initial_state (array-like): Current state used as the operating point.
            target: Reference passed straight through to ``Optimizer.solve``.
            prev_cmd (array-like): Previous control input used as the
                operating-point input.

        Returns:
            tuple: ``(x_opt, u_opt)`` exactly as returned by ``Optimizer.solve``.
        """
        A, B, C = self.get_linear_model_matrices_roomba(initial_state, prev_cmd)

        # Delegate the optimization itself to the Optimizer instance.
        x_opt, u_opt = self.optimizer.solve(initial_state, target, prev_cmd, A, B, C, self.robot_model, self.dt)

        return x_opt, u_opt
    
if __name__ == "__main__":
    # Example usage: linearize the Roomba model about a sample operating point.
    # Roomba is imported here because the module-level import is commented out
    # (possibly to avoid an import cycle; TODO confirm). Without this import
    # the demo raises NameError. Assumes guided_mrmp.utils exposes Roomba.
    from guided_mrmp.utils import Roomba

    dt = 0.1
    roomba = Roomba()
    Q = [20, 20, 20]  # state error cost
    Qf = [30, 30, 30]  # state final error cost
    R = [10, 10]  # input cost
    P = [10, 10]  # input rate of change cost
    mpc = MPC(roomba, 5, dt, Q, Qf, R, P)
    x_bar = np.array([0.0, 0.0, 0.0])
    u_bar = np.array([1.0, 0.1])
    A_lin, B_lin, C_lin = mpc.get_linear_model_matrices_roomba(x_bar, u_bar)
    print(A_lin)
    print(B_lin)
    print(C_lin)