from guided_mrmp.controllers.path_tracker import *

from guided_mrmp.planners.singlerobot.RRTStar import RRTStar


from guided_mrmp.utils import Roomba, Env

def plot(x_histories, y_histories, h_histories, wp_paths):
    """Draw every reference path and tracked trajectory in one static figure.

    The four arguments are parallel sequences with one entry per robot; they
    are consumed in lockstep with ``zip``, so extra entries in a longer
    sequence are ignored.

    Args:
        x_histories: per-robot sequences of x positions.
        y_histories: per-robot sequences of y positions.
        h_histories: per-robot sequences of headings (passed to
            ``plot_roomba`` as ``yaw``).
        wp_paths: per-robot reference paths, indexed as ``path[0, :]`` for x
            and ``path[1, :]`` for y.

    The function opens a new figure in interactive mode and then blocks on
    ``input()`` until the user presses Enter.
    """
    plt.style.use("ggplot")
    plt.figure()
    plt.ion()
    plt.show()
    plt.clf()

    drawn_paths = []
    for x_history, y_history, h_history, path in zip(
        x_histories, y_histories, h_histories, wp_paths
    ):
        plt.plot(
            path[0, :],
            path[1, :],
            c="tab:orange",
            marker=".",
            label="reference track",
        )

        plt.plot(
            x_history,
            y_history,
            c="tab:blue",
            marker=".",
            alpha=0.5,
            label="vehicle trajectory",
        )

        # plot_roomba takes a color argument; omitting it raised a TypeError.
        # Use the trajectory color so the final pose matches its track.
        plot_roomba(x_history[-1], y_history[-1], h_history[-1], "tab:blue")
        drawn_paths.append(path)

    plt.ylabel("y")
    plt.xlabel("x")

    # Set the ticks once from the extent of ALL drawn paths. Setting them
    # inside the loop kept only the last path's range.
    if drawn_paths:
        x_lo = min(np.min(p[0, :]) for p in drawn_paths)
        x_hi = max(np.max(p[0, :]) for p in drawn_paths)
        y_lo = min(np.min(p[1, :]) for p in drawn_paths)
        y_hi = max(np.max(p[1, :]) for p in drawn_paths)
        # One unit of margin on each side; np.arange excludes its stop value,
        # hence +2.0 to include the tick at max + 1.
        plt.yticks(np.arange(y_lo - 1.0, y_hi + 2.0, 1.0))
        plt.xticks(np.arange(x_lo - 1.0, x_hi + 2.0, 1.0))

    plt.axis("equal")
    plt.tight_layout()

    plt.draw()
    plt.pause(0.1)
    # Keep the window open until the user presses Enter.
    input()

def get_traj_from_points(start, dynamics, target_v, T, DT, waypoints):
    """Track ``waypoints`` with a ``PathTracker`` and return the results.

    All arguments are forwarded unchanged to the ``PathTracker`` constructor
    (``start`` as ``initial_position``). The tracker is run with plotting
    disabled.

    Returns:
        A 4-tuple ``(x, y, h, path)``: the three values returned by
        ``PathTracker.run`` followed by the tracker's ``path`` attribute.
    """
    tracker = PathTracker(
        initial_position=start,
        dynamics=dynamics,
        target_v=target_v,
        T=T,
        DT=DT,
        waypoints=waypoints,
    )
    x_hist, y_hist, h_hist = tracker.run(show_plots=False)
    return x_hist, y_hist, h_hist, tracker.path


def plot_sim(x_histories, y_histories, h_histories, wp_paths):
    """Animate all robots along their trajectories, one frame per time step.

    The arguments are parallel per-robot sequences (see the indexing below:
    each path is read as ``path[0, :]`` / ``path[1, :]``). Each robot gets
    its own color, taken from a colormap chosen by the number of robots.
    The figure is cleared and redrawn for every frame; a robot whose history
    is shorter than the longest one stays at its last recorded pose. After
    the last frame the function blocks on ``input()``.
    """
    n_robots = len(x_histories)
    samples_01 = np.linspace(0, 1, n_robots)

    # Pick a colormap by the number of robots to draw.
    if n_robots > 20:
        palette = plt.cm.hsv(np.linspace(0.2, 1.0, n_robots))
    elif n_robots > 10:
        palette = plt.cm.tab20(samples_01)
    else:
        palette = plt.cm.tab10(samples_01)
    plt.clf()

    n_frames = max(len(xs) for xs in x_histories)

    for frame in range(n_frames):
        plt.clf()
        for xs, ys, hs, ref, rgba in zip(
            x_histories, y_histories, h_histories, wp_paths, palette
        ):
            plt.plot(
                ref[0, :],
                ref[1, :],
                c=rgba,
                marker=".",
                markersize=1,
                label="reference track",
            )

            # Trail of the poses visited before the current frame.
            plt.plot(
                xs[:frame],
                ys[:frame],
                c=rgba,
                marker=".",
                alpha=0.5,
                label="vehicle trajectory",
            )

            # Robots that already finished are drawn at their final pose.
            pose_idx = frame if frame < len(xs) else -1
            plot_roomba(xs[pose_idx], ys[pose_idx], hs[pose_idx], rgba)

        # Set x-axis and y-axis ranges
        plt.xlim(-10, 10)
        plt.ylim(-10, 10)
        plt.axis("equal")
        plt.tight_layout()

        plt.draw()
        plt.pause(0.1)

    # Keep the window open until the user presses Enter.
    input()



def plot_roomba(x, y, yaw, color="tab:blue"):
    """Draw a robot marker on the current axes: a circle plus a heading arrow.

    Args:
        x: x coordinate of the circle centre and of the arrow tail.
        y: y coordinate of the circle centre and of the arrow tail.
        yaw: heading angle in radians (it is fed to ``np.cos``/``np.sin``).
        color: matplotlib color of the circle outline. Defaults to
            ``"tab:blue"`` so that three-argument calls keep working. The
            heading arrow is always drawn in red.
    """
    ax = plt.gcf().gca()

    # Unfilled circle of radius 0.5 centred on the robot position.
    ax.add_patch(plt.Circle((x, y), 0.5, color=color, fill=False))

    # Unit-length direction marker pointing along the heading.
    dx = np.cos(yaw)
    dy = np.sin(yaw)
    ax.arrow(x, y, dx, dy, head_width=0.1, head_length=0.1, fc="r", ec="r")



if __name__ == "__main__":
    # Demo: plan one RRT* path per robot, track each path with the
    # controller, then show a static plot followed by an animation.

    # Tracker settings shared by both robots.
    target_velocity = 3.0  # m/s
    T = 1  # Prediction Horizon [s]
    DT = 0.2  # discretization step [s]

    # NOTE(review): the Env arguments look like an x-range, a y-range and two
    # empty obstacle collections -- confirm against guided_mrmp.utils.Env.
    env = Env([0, 10], [0, 10], [], [])

    # ---- Robot 1: plan from (0, 0) to (10, 3) ----
    start_state_1 = np.array([0.0, 0.1, 0.0, 0.0])
    planner_1 = RRTStar(env, (0, 0), (10, 3), 0.5, 0.05, 1000, r=2.0)
    # The planner output is reversed before use (presumably run() returns the
    # path goal-first -- verify against RRTStar).
    nodes_1 = list(reversed(planner_1.run()))
    # Waypoints are passed as [all x values, all y values].
    wp_1 = [[node[0] for node in nodes_1], [node[1] for node in nodes_1]]

    dynamics = Roomba()

    x1, y1, h1, path1 = get_traj_from_points(
        start_state_1, dynamics, target_velocity, T, DT, wp_1
    )

    # ---- Robot 2: plan from (10, 5) to (1, 1) ----
    start_state_2 = np.array([10.0, 5.1, 0.0, 0.0])
    planner_2 = RRTStar(env, (10, 5), (1, 1), 0.5, 0.05, 500, r=2.0)
    nodes_2 = list(reversed(planner_2.run()))
    wp_2 = [[node[0] for node in nodes_2], [node[1] for node in nodes_2]]

    x2, y2, h2, path2 = get_traj_from_points(
        start_state_2, dynamics, target_velocity, T, DT, wp_2
    )

    # Static overview first, then the frame-by-frame animation.
    x_all, y_all, h_all = [x1, x2], [y1, y2], [h1, h2]
    plot(x_all, y_all, h_all, [path1, path2])
    plot_sim(x_all, y_all, h_all, [path1, path2])