import numpy as np
import matplotlib.pyplot as plt

from guided_mrmp.controllers.utils import compute_path_from_wp, get_ref_trajectory
from guided_mrmp.controllers.mpc import MPC
from guided_mrmp.utils import Roomba

# Classes
class PathTracker:
    """Track a waypoint-defined path with MPC and simulate the robot along it.

    The reference path is interpolated from the waypoints once at construction.
    Each simulation step asks the MPC for a control, advances the state with
    ``dynamics.next_state`` and records the visited poses; ``run`` returns those
    histories once the robot is within 0.1 of the path's last point.
    """

    def __init__(self, initial_position, dynamics, target_v, T, DT, waypoints, settings):
        """
        Initializes the PathTracker object.
        Parameters:
        - initial_position: The initial position of the robot [x, y, heading].
        - dynamics: The dynamics model of the robot. It is handed to the MPC and
          must provide ``next_state(state, control, DT)`` for ``run``.
        - target_v: The target velocity of the robot.
        - T: The time horizon for the model predictive control (MPC).
        - DT: The time step for the MPC.
        - waypoints: The waypoints defining the desired path; ``waypoints[0]``
          and ``waypoints[1]`` are passed to ``compute_path_from_wp`` with a
          third argument of 0.05 (presumably the x/y coordinate lists and the
          interpolation step -- confirm against that helper).
        - settings: Configuration mapping; its ``'model_predictive_controller'``
          entry is forwarded to the MPC.
        """
        # State of the robot [x, y, heading]
        self.state = initial_position
        self.dynamics = dynamics
        self.T = T
        self.DT = DT
        self.target_v = target_v

        # Last control applied; it is also passed to every MPC solve.
        # Starting condition is [0, 0].
        self.control = np.zeros(2)

        # Number of steps in the prediction horizon.
        # NOTE(review): int() truncates (e.g. int(0.6 / 0.2) == 2); keep this in
        # sync with however the MPC derives its own horizon before changing it.
        self.K = int(T / DT)

        # Cost weights for a circular robot (3-element state, 2 inputs).
        # For the car model the previous configuration was:
        #   Q = [20, 20, 10, 20], Qf = [30, 30, 30, 30], R = [10, 10], P = [10, 10]
        #   self.mpc = MPC(VehicleModel(), T, DT, Q, Qf, R, P)
        Q = [20, 20, 20]  # state error cost
        Qf = [30, 30, 30]  # state final error cost
        R = [10, 10]  # input cost
        P = [10, 10]  # input rate of change cost
        self.mpc = MPC(dynamics, T, DT, Q, Qf, R, P, settings['model_predictive_controller'])

        # Path from waypoint interpolation
        self.path = compute_path_from_wp(waypoints[0], waypoints[1], 0.05)

        # Simulation bookkeeping. x/y/h histories are filled by ``run``;
        # d_history (second control input) is filled by ``plot_sim``.
        self.sim_time = 0
        self.x_history = []
        self.y_history = []
        self.v_history = []
        self.h_history = []
        self.a_history = []
        self.d_history = []
        self.optimized_trajectory = None

    def ego_to_global(self, mpc_out):
        """
        Transforms optimized trajectory XY points from ego (robot) reference
        into global (map) frame. Legacy helper for the car model.

        The heading is read from ``self.state[3]``. This assumes a car-model
        state layout with the heading at index 3 (TODO confirm). With a
        3-element [x, y, heading] state it raises IndexError; use
        ``ego_to_global_roomba`` for that case.

        Args:
            mpc_out (numpy array): Predicted states in the ego frame. Rows 0-1
                are x/y. It is expected to have ``self.K + 1`` columns; column 0
                (presumably the current state) is dropped.

        Returns:
            numpy array: (2, K) XY points in the global frame.
        """
        trajectory = np.zeros((2, self.K))
        trajectory[:, :] = mpc_out[0:2, 1:]
        # Multiplying row vectors by this matrix rotates them by +heading.
        Rotm = np.array(
            [
                [np.cos(self.state[3]), np.sin(self.state[3])],
                [-np.sin(self.state[3]), np.cos(self.state[3])],
            ]
        )
        trajectory = (trajectory.T.dot(Rotm)).T
        trajectory[0, :] += self.state[0]
        trajectory[1, :] += self.state[1]
        return trajectory

    def ego_to_global_roomba(self, mpc_out):
        """
        Transforms optimized trajectory XY points from ego (robot) reference
        into global (map) frame, using the heading at ``self.state[2]``.

        Args:
            mpc_out (numpy array): Optimized trajectory points in ego reference
                frame; rows 0 and 1 are used as x and y.

        Returns:
            numpy array: (2, N) transformed trajectory points in global frame.
            ``mpc_out`` itself is not modified.
        """
        # Extract x, y, and theta from the state
        x = self.state[0]
        y = self.state[1]
        theta = self.state[2]

        # Rotation matrix to transform points from ego frame to global frame
        Rotm = np.array([
            [np.cos(theta), -np.sin(theta)],
            [np.sin(theta), np.cos(theta)]
        ])

        # Rotate the XY rows; dot() returns a new array, so the in-place
        # translation below does not touch mpc_out.
        trajectory = Rotm.dot(mpc_out[0:2, :])

        # Translate the points to the robot's position in the global frame
        trajectory[0, :] += x
        trajectory[1, :] += y

        return trajectory

    def get_next_control(self, state, show_plots=False):
        """Solve one MPC step and return the prediction and the control to apply.

        Args:
            state: Current robot state in the world frame, used to build the
                reference trajectory.
            show_plots: Unused; kept for interface compatibility.

        Returns:
            tuple: ``(x_mpc, control)``. ``x_mpc`` is the MPC's predicted state
            trajectory. ``control`` is a copy of the first input
            ``[u_mpc[0, 0], u_mpc[1, 0]]``, which is also stored in
            ``self.control``.
        """
        # Reference trajectory -- inputs are in the world frame.
        target = get_ref_trajectory(np.array(state), np.array(self.path), self.target_v, self.T, self.DT, 0)

        # The MPC is solved relative to the robot frame, so it starts at the origin.
        curr_state = np.array([0, 0, 0])
        x_mpc, u_mpc = self.mpc.step(
            curr_state,
            target,
            self.control
        )

        # Only the first input is applied. It is kept in self.control and passed
        # to the next solve (presumably as the previous input for the
        # rate-of-change cost).
        self.control[:] = [u_mpc[0, 0], u_mpc[1, 0]]

        # Return a copy so values a caller already holds are not overwritten by
        # the in-place update on the next call.
        return x_mpc, self.control.copy()

    def run(self, show_plots=False):
        """
        Run the path tracker algorithm until the goal is reached.

        Stops once the robot's (x, y) is within 0.1 of the last point of the
        reference path. There is no step limit, so this loops forever if the
        goal is never reached.

        Parameters:
        - show_plots (bool): Flag indicating whether to show plots during the simulation. Default is False.
        Returns:
        - numpy.ndarray: Array of shape (3, N) containing the history of x, y, and h coordinates.
        """
        # Add the initial state to the histories
        self.x_history.append(self.state[0])
        self.y_history.append(self.state[1])
        self.h_history.append(self.state[2])
        if show_plots:
            self.plot_sim()

        # The goal is fixed for the whole run.
        goal_x = self.path[0, -1]
        goal_y = self.path[1, -1]

        while True:
            if np.sqrt((self.state[0] - goal_x) ** 2 + (self.state[1] - goal_y) ** 2) < 0.1:
                print("Success! Goal Reached")
                return np.asarray([self.x_history, self.y_history, self.h_history])

            x_mpc, controls = self.get_next_control(self.state)
            self.state = self.dynamics.next_state(self.state, [controls[0], controls[1]], self.DT)

            self.x_history.append(self.state[0])
            self.y_history.append(self.state[1])
            self.h_history.append(self.state[2])

            if show_plots:
                # Preview the MPC's predicted trajectory in the world frame.
                self.optimized_trajectory = self.ego_to_global_roomba(x_mpc)
                self.plot_sim()

    def plot_sim(self):
        """Advance the plot clock, log the current control and redraw the figure.

        Left panel: the reference path, the visited positions, the MPC preview
        (if any) and the robot. Bottom-right panel: the history of
        ``control[1]`` in degrees.

        Pose histories (x/y/h) are owned by ``run``. This method only appends to
        ``d_history``, so those histories stay the same length.
        """
        self.sim_time = self.sim_time + self.DT
        self.d_history.append(self.control[1])

        plt.clf()

        grid = plt.GridSpec(2, 3)

        # Map view: reference path, driven trajectory, MPC preview and robot.
        plt.subplot(grid[0:2, 0:2])
        plt.title("MPC Simulation \n" + "Simulation elapsed time {}s".format(self.sim_time))

        plt.plot(
            self.path[0, :],
            self.path[1, :],
            c="tab:orange",
            marker=".",
            label="reference track",
        )

        plt.plot(
            self.x_history,
            self.y_history,
            c="tab:blue",
            marker=".",
            alpha=0.5,
            label="vehicle trajectory",
        )

        if self.optimized_trajectory is not None:
            plt.plot(
                self.optimized_trajectory[0, :],
                self.optimized_trajectory[1, :],
                c="tab:green",
                marker="+",
                alpha=0.5,
                label="mpc opt trajectory",
            )

        plot_roomba(self.x_history[-1], self.y_history[-1], self.h_history[-1])

        plt.ylabel("map y")
        plt.yticks(np.arange(min(self.path[1, :]) - 1.0, max(self.path[1, :] + 1.0) + 1, 1.0))
        plt.xlabel("map x")
        plt.xticks(np.arange(min(self.path[0, :]) - 1.0, max(self.path[0, :] + 1.0) + 1, 1.0))
        plt.axis("equal")

        # Top-right panel is currently left empty.
        plt.subplot(grid[0, 2])

        # Bottom-right panel: second control input over time.
        plt.subplot(grid[1, 2])
        plt.plot(np.degrees(self.d_history), c="tab:orange")
        plt.ylabel("gamma(t) [deg]")
        locs, _ = plt.xticks()
        # Relabel sample indices as seconds. This must use self.DT: a bare `DT`
        # only exists when this file is run as a script.
        plt.xticks(locs[1:], locs[1:] * self.DT)
        plt.xlabel("t [s]")

        plt.tight_layout()

        # NOTE(review): gca() is the bottom-right control panel at this point;
        # confirm these limits were not meant for the map view.
        ax = plt.gca()
        ax.set_xlim([0, 10])
        ax.set_ylim([0, 10])

        plt.draw()
        plt.pause(0.1)

def plot_roomba(x, y, yaw, radius=0.5, arrow_length=1.0):
    """Draw a circular robot footprint and a heading arrow on the current axes.

    Args:
        x (float): Robot x position in map coordinates.
        y (float): Robot y position in map coordinates.
        yaw (float): Robot heading in radians.
        radius (float): Radius of the unfilled blue footprint circle. Defaults to 0.5.
        arrow_length (float): Length of the red heading arrow. Defaults to 1.0.
    """
    ax = plt.gca()

    # Robot footprint
    ax.add_patch(plt.Circle((x, y), radius, color='b', fill=False))

    # Direction marker pointing along the heading
    dx = arrow_length * np.cos(yaw)
    dy = arrow_length * np.sin(yaw)
    ax.arrow(x, y, dx, dy, head_width=0.1, head_length=0.1, fc='r', ec='r')

if __name__ == "__main__":
    # Example usage: track a fixed waypoint loop with a Roomba model and
    # animate the result.
    import yaml

    settings_path = "settings_files/settings.yaml"
    with open(settings_path, 'r') as settings_file:
        settings = yaml.safe_load(settings_file)

    dynamics = Roomba(settings)

    # NOTE(review): four values, while PathTracker documents [x, y, heading];
    # confirm what the Roomba model expects for the 4th entry.
    initial_pos = np.array([0.0, 0.5, 0.0, 0.0])
    target_velocity = 3.0  # m/s
    T = 1  # Prediction Horizon [s]
    DT = 0.2  # discretization step [s]
    wp = [
        [0, 3, 4, 6, 10, 12, 13, 13, 6, 1, 0],
        [0, 0, 2, 4, 3, 3, -1, -2, -6, -2, -2],
    ]

    sim = PathTracker(
        initial_position=initial_pos,
        dynamics=dynamics,
        target_v=target_velocity,
        T=T,
        DT=DT,
        waypoints=wp,
        settings=settings,
    )
    x, y, h = sim.run(show_plots=True)