import numpy as np
from guided_mrmp.utils import Roomba

np.seterr(divide="ignore", invalid="ignore")

import cvxpy as opt


class MPC:
    """MPC tracker for a differential-drive (Roomba-like) robot, solved with cvxpy.

    The state is ``[x, y, theta]`` and the input is ``[v, omega]``. On every call
    to :meth:`step` the unicycle kinematics are linearized around the current
    state and the previous command, and a convex QP over the horizon is solved.
    """

    def __init__(self, model, T, DT, state_cost, final_state_cost, input_cost, input_rate_cost):
        """Configure the controller.

        Args:
            model: robot model. ``step`` reads ``max_acc``, ``max_steer``,
                ``max_d_acc`` and ``max_d_steer`` from it;
                ``get_linear_model_matrices`` reads ``wheelbase``.
            T (float): prediction horizon length (same time units as ``DT``).
            DT (float): discretization time step.
            state_cost (sequence of 3 numbers): diagonal of the per-step state error weight Q.
            final_state_cost (sequence of 3 numbers): diagonal of the terminal state error weight Qf.
            input_cost (sequence of 2 numbers): diagonal of the control effort weight R.
            input_rate_cost (sequence of 2 numbers): diagonal of the control rate-of-change weight P.

        Raises:
            ValueError: if any cost vector has the wrong length.
        """
        self.nx = 3  # number of state vars: [x, y, theta]
        self.nu = 2  # number of input/control vars: [v, omega]

        if len(state_cost) != self.nx:
            raise ValueError(f"State Error cost matrix should be of size {self.nx}")
        if len(final_state_cost) != self.nx:
            raise ValueError(f"End State Error cost matrix should be of size {self.nx}")
        if len(input_cost) != self.nu:
            raise ValueError(f"Control Effort cost matrix should be of size {self.nu}")
        if len(input_rate_cost) != self.nu:
            raise ValueError(
                f"Control Effort Difference cost matrix should be of size {self.nu}"
            )

        self.robot_model = model
        self.dt = DT

        # Number of control intervals that fit in the horizon. The epsilon keeps
        # float round-off (e.g. 0.3 / 0.1 == 2.9999999999999996) from silently
        # dropping an interval; otherwise the ratio is still truncated.
        self.control_horizon = int(T / DT + 1e-9)

        # Weight for the error in state
        self.Q = np.diag(state_cost)

        # Weight for the error in final state
        self.Qf = np.diag(final_state_cost)

        # Weight for control effort and for control rate of change
        self.R = np.diag(input_cost)
        self.P = np.diag(input_rate_cost)

    def get_linear_model_matrices_roomba(self, x_bar, u_bar):
        """Linearize and discretize the unicycle model around ``(x_bar, u_bar)``.

        Continuous model: x' = v cos(theta), y' = v sin(theta), theta' = omega.
        The first-order expansion is discretized with forward Euler, giving
        x_{k+1} = A_lin x_k + B_lin u_k + C_lin.

        Args:
            x_bar (array-like): operating-point state [x, y, theta].
            u_bar (array-like): operating-point input [v, omega].

        Returns:
            tuple: ``(A_lin, B_lin, C_lin)`` with shapes (nx, nx), (nx, nu), (nx,).
        """
        x_bar = np.asarray(x_bar, dtype=float).reshape(self.nx)
        u_bar = np.asarray(u_bar, dtype=float).reshape(self.nu)

        theta = x_bar[2]
        v, omega = u_bar

        ct = np.cos(theta)
        st = np.sin(theta)

        # Jacobian of f w.r.t. the state: only the heading enters nonlinearly.
        A = np.zeros((self.nx, self.nx))
        A[0, 2] = -v * st
        A[1, 2] = v * ct

        # Jacobian of f w.r.t. the input.
        B = np.zeros((self.nx, self.nu))
        B[0, 0] = ct
        B[1, 0] = st
        B[2, 1] = 1.0

        # Nonlinear dynamics evaluated at the operating point.
        f_xu = np.array([v * ct, v * st, omega])

        A_lin = np.eye(self.nx) + self.dt * A
        B_lin = self.dt * B
        # Affine term that makes the linear model exact at (x_bar, u_bar).
        C_lin = self.dt * (f_xu - A @ x_bar - B @ u_bar)

        return A_lin, B_lin, C_lin

    def get_linear_model_matrices(self, x_bar, u_bar):
        """Linearize and discretize a kinematic bicycle model around ``(x_bar, u_bar)``.

        This model has state [x, y, v, theta] and input [a, delta], so it uses
        its own dimensions (4, 2) rather than the controller's ``self.nx``/``self.nu``.
        Continuous model: x' = v cos(theta), y' = v sin(theta), v' = a,
        theta' = v tan(delta) / L, with L = ``robot_model.wheelbase``.

        Args:
            x_bar (array-like): operating-point state [x, y, v, theta].
            u_bar (array-like): operating-point input [a, delta].

        Returns:
            tuple: ``(A_lin, B_lin, C_lin)`` with shapes (4, 4), (4, 2), (4,).
        """
        nx, nu = 4, 2
        x_bar = np.asarray(x_bar, dtype=float).reshape(nx)
        u_bar = np.asarray(u_bar, dtype=float).reshape(nu)
        wheelbase = self.robot_model.wheelbase

        v = x_bar[2]
        theta = x_bar[3]
        a, delta = u_bar

        ct = np.cos(theta)
        st = np.sin(theta)
        cd = np.cos(delta)
        td = np.tan(delta)

        # Jacobian w.r.t. the state.
        A = np.zeros((nx, nx))
        A[0, 2] = ct
        A[0, 3] = -v * st
        A[1, 2] = st
        A[1, 3] = v * ct
        # d(v * tan(delta) / L) / dv = tan(delta) / L
        A[3, 2] = td / wheelbase

        # Jacobian w.r.t. the input; d tan(delta)/d delta = 1 / cos^2(delta).
        B = np.zeros((nx, nu))
        B[2, 0] = 1.0
        B[3, 1] = v / (wheelbase * cd**2)

        f_xu = np.array([v * ct, v * st, a, v * td / wheelbase])

        A_lin = np.eye(nx) + self.dt * A
        B_lin = self.dt * B
        C_lin = self.dt * (f_xu - A @ x_bar - B @ u_bar)
        return A_lin, B_lin, C_lin

    def step(self, initial_state, target, prev_cmd):
        """Build and solve one MPC problem.

        Args:
            initial_state (array-like): current estimate of [x, y, heading].
            target (ndarray): reference states of shape (nx, control_horizon), in the
                same frame as ``initial_state``; column k is the reference for
                predicted state k + 1.
            prev_cmd (array-like): previously applied command [v, omega]. It is the
                linearization input and bounds the first control's rate of change.

        Returns:
            tuple: ``(x, u)`` cvxpy Variables of shapes (nx, control_horizon + 1) and
            (nu, control_horizon). Read ``.value`` for the optimized trajectory;
            callers should check it is not None (no solution found).
        """
        assert len(initial_state) == self.nx
        assert len(prev_cmd) == self.nu
        assert target.shape == (self.nx, self.control_horizon)

        initial_state = np.asarray(initial_state, dtype=float).reshape(self.nx)
        prev_cmd = np.asarray(prev_cmd, dtype=float).reshape(self.nu)
        horizon = self.control_horizon

        x = opt.Variable((self.nx, horizon + 1), name="states")
        u = opt.Variable((self.nu, horizon), name="actions")

        # NOTE: the linearization is performed once, around the starting condition,
        # to keep the controller simple. It gets less accurate further into the
        # horizon; for better tracking, re-linearize per step around the previous
        # solution: Ak, Bk, Ck = self.get_linear_model_matrices_roomba(x_prev[:, k], u_prev[:, k])
        A, B, C = self.get_linear_model_matrices_roomba(initial_state, prev_cmd)

        cost = 0
        constr = [x[:, 0] == initial_state]

        for k in range(horizon):
            # Tracking error: stay close to the reference at every predicted step.
            cost += opt.quad_form(x[:, k + 1] - target[:, k], self.Q)
            # Actuation magnitude: prefer small controls.
            cost += opt.quad_form(u[:, k], self.R)
            # Linearized kinematics: x_{k+1} = A x_k + B u_k + C.
            constr += [x[:, k + 1] == A @ x[:, k] + B @ u[:, k] + C]

        # Extra weight on matching the final reference point.
        cost += opt.quad_form(x[:, -1] - target[:, -1], self.Qf)

        # Actuation rate of change: prefer smooth control sequences.
        for k in range(1, horizon):
            cost += opt.quad_form(u[:, k] - u[:, k - 1], self.P)

        # Actuation bounds over the whole horizon: row 0 is the first control
        # channel at every step, row 1 the second.
        # NOTE(review): for the Roomba, input 0 is a velocity, so `max_acc` presumably
        # acts as a speed limit here — confirm against the robot model.
        constr += [opt.abs(u[0, :]) <= self.robot_model.max_acc]
        constr += [opt.abs(u[1, :]) <= self.robot_model.max_steer]

        # Actuation rate-of-change bounds, starting from the previously applied command.
        constr += [opt.abs(u[0, 0] - prev_cmd[0]) / self.dt <= self.robot_model.max_d_acc]
        constr += [opt.abs(u[1, 0] - prev_cmd[1]) / self.dt <= self.robot_model.max_d_steer]
        for k in range(1, horizon):
            constr += [opt.abs(u[0, k] - u[0, k - 1]) / self.dt <= self.robot_model.max_d_acc]
            constr += [opt.abs(u[1, k] - u[1, k - 1]) / self.dt <= self.robot_model.max_d_steer]

        prob = opt.Problem(opt.Minimize(cost), constr)
        prob.solve(solver=opt.OSQP, warm_start=True, verbose=False)
        return x, u
    
if __name__ == "__main__":
    # Demo: build a controller for a Roomba, then print the discretized
    # linear model (A_lin, B_lin, C_lin) around a sample operating point.
    sample_dt = 0.1
    robot = Roomba()
    prediction_horizon = 5
    controller = MPC(
        robot,
        prediction_horizon,
        sample_dt,
        [20, 20, 20],  # state error cost
        [30, 30, 30],  # final state error cost
        [10, 10],  # input cost
        [10, 10],  # input rate of change cost
    )
    operating_state = np.array([0.0, 0.0, 0.0])
    operating_input = np.array([1.0, 0.1])
    for matrix in controller.get_linear_model_matrices_roomba(operating_state, operating_input):
        print(matrix)