import cvxpy as opt
import numpy as np

class Optimizer:
    """Finite-horizon tracking problem with linear dynamics, built with cvxpy.

    The cost is a sum of quadratic forms on the tracking error, the input
    magnitude and the input change between consecutive steps. The constraints
    are the affine dynamics, the initial state, and box bounds on the inputs
    and on their rate of change.
    """

    def __init__(self, nx, nu, control_horizon, Q, Qf, R, P):
        """Store the problem dimensions and the cost weights.

        Args:
            nx: number of rows of the state trajectory variable.
            nu: number of rows of the input trajectory variable.
            control_horizon: number of input steps in the horizon.
            Q: weight of the per-step tracking error.
            Qf: weight of the final-state tracking error.
            R: weight of the input magnitude.
            P: weight of the input change between consecutive steps.
        """
        self.nx = nx
        self.nu = nu
        self.control_horizon = control_horizon
        self.Q = Q
        self.Qf = Qf
        self.R = R
        self.P = P

    def solve(self, initial_state, target, prev_cmd, A, B, C, robot_model, dt):
        """
        Sets up and solves the optimization problem.

        Args:
            initial_state (array-like): current state estimate; the first
                column of the state trajectory is constrained to equal it
            target (ndarray): state space reference, in the same frame as the
                provided current state. Column k is the reference for state
                column k + 1, so at least ``control_horizon`` columns are
                read; the last column is also used for the final cost
            prev_cmd (array-like): previously applied input, ordered like the
                rows of the input trajectory (index 0 is rate-limited with
                ``max_d_acc``, index 1 with ``max_d_steer``)
            A, B, C: Linearized state-space matrices, used as
                ``x[k + 1] = A x[k] + B u[k] + C``
            robot_model: object providing ``max_acc``, ``max_steer``,
                ``max_d_acc`` and ``max_d_steer``
            dt: Time step used to convert input differences into rates

        Returns:
            x, u: the cvxpy variables holding the state trajectory
            (``nx`` x ``control_horizon + 1``) and the input trajectory
            (``nu`` x ``control_horizon``). Read ``.value`` to get numbers;
            cvxpy leaves it as ``None`` when no solution was found, so
            callers should check for that.
        """
        horizon = self.control_horizon

        # set up variables for the optimization problem
        x = opt.Variable((self.nx, horizon + 1), name="states")
        u = opt.Variable((self.nu, horizon), name="actions")
        cost = 0
        constr = []

        # Tracking error cost
        for k in range(horizon):
            cost += opt.quad_form(x[:, k + 1] - target[:, k], self.Q)

        # Final point tracking cost
        cost += opt.quad_form(x[:, -1] - target[:, -1], self.Qf)

        # Actuation magnitude cost
        for k in range(horizon):
            cost += opt.quad_form(u[:, k], self.R)

        # Actuation rate of change cost
        for k in range(1, horizon):
            cost += opt.quad_form(u[:, k] - u[:, k - 1], self.P)

        # Kinematics Constraints
        for k in range(horizon):
            constr.append(x[:, k + 1] == A @ x[:, k] + B @ u[:, k] + C)

        # initial state
        constr.append(x[:, 0] == initial_state)

        # Actuation bounds. Rows of u are input channels and columns are time
        # steps, so each limit must be applied to a whole row; indexing a
        # column here would bound every input at a single step instead.
        constr.append(opt.abs(u[0, :]) <= robot_model.max_acc)
        constr.append(opt.abs(u[1, :]) <= robot_model.max_steer)

        # Actuation rate of change bounds: the first step is compared with
        # the previously applied command, later steps with their predecessor.
        constr.append(opt.abs(u[0, 0] - prev_cmd[0]) / dt <= robot_model.max_d_acc)
        constr.append(opt.abs(u[1, 0] - prev_cmd[1]) / dt <= robot_model.max_d_steer)
        for k in range(1, horizon):
            constr.append(
                opt.abs(u[0, k] - u[0, k - 1]) / dt <= robot_model.max_d_acc
            )
            constr.append(
                opt.abs(u[1, k] - u[1, k - 1]) / dt <= robot_model.max_d_steer
            )

        prob = opt.Problem(opt.Minimize(cost), constr)
        prob.solve(solver=opt.OSQP, warm_start=True, verbose=False)
        return x, u