import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Circle, Rectangle
from casadi import *
from guided_mrmp.conflict_resolvers.curve_path import smooth_path, calculate_headings

from guided_mrmp.conflict_resolvers.traj_opt_resolver import TrajOptResolver

def plot_paths_db(circle_obs, num_robots, starts, goals, x_opt, initial_guess, x_range, y_range, radius, title=""):
    """Plot obstacles, start/goal markers and optional robot paths, then show the figure.

    Args:
        circle_obs: iterable of circular obstacles, each indexed as
            ``(x, y, radius)``.
        num_robots: number of robots to draw.
        starts: per-robot start positions; only indices 0 (x) and 1 (y) are read.
        goals: per-robot goal positions; only indices 0 (x) and 1 (y) are read.
        x_opt: optional 2-D array of optimized positions where rows ``2*r`` and
            ``2*r + 1`` hold the x and y samples of robot ``r``; ``None`` skips
            the optimized paths and the robot outlines at the starts.
        initial_guess: optional 2-D array where rows ``3*r``, ``3*r + 1`` and
            ``3*r + 2`` hold x, y and heading (radians) of robot ``r``; ``None``
            skips the initial-guess overlay.
        x_range: sequence whose index 1 is used as the upper x limit.
        y_range: sequence whose index 1 is used as the upper y limit.
        radius: robot radius forwarded to ``plot_roomba``.
        title: figure title.

    This call blocks in ``plt.show()`` until the window is closed.
    """
    _, ax = plt.subplots()

    # Plot obstacles. Circle expects a 2-element centre, so pass (x, y)
    # explicitly instead of the whole (x, y, radius) triple.
    for obstacle in circle_obs:
        ax.add_patch(Circle((obstacle[0], obstacle[1]), obstacle[2], color='red'))

    colors = plt.cm.Set1(np.linspace(0, 1, num_robots))

    # Plot robot paths
    for r, color in zip(range(num_robots), colors):
        if x_opt is not None:
            ax.plot(x_opt[r*2, :], x_opt[r*2+1, :], label=f'Robot {r+1}', color=color)
            ax.scatter(x_opt[r*2, :], x_opt[r*2+1, :], color=color, s=10)

        # Filled marker at the start, hollow marker at the goal.
        ax.scatter(starts[r][0], starts[r][1], s=85, color=color)
        ax.scatter(goals[r][0], goals[r][1], s=135, facecolors='none', edgecolors=color)

        if initial_guess is not None:
            x = initial_guess[r*3, :]
            y = initial_guess[r*3+1, :]
            heading = initial_guess[r*3+2, :]

            ax.plot(x, y, color=color, linestyle='--')
            ax.scatter(x, y, color=color, s=5)

            # One short arrow per sample showing the guessed heading.
            for x0, y0, heading0 in zip(x, y, heading):
                dx = 0.1 * np.cos(heading0)
                dy = 0.1 * np.sin(heading0)
                ax.arrow(x0, y0, dx, dy, head_width=0.05, head_length=0.1, fc=color, ec=color)

        # The robot outline at the start is only drawn alongside an optimized path.
        if x_opt is not None:
            plot_roomba(starts[r][0], starts[r][1], 0, color, radius)

    plt.ylim(0, y_range[1])
    plt.xlim(0, x_range[1])
    # NOTE(review): "equal" is applied after the limits, so matplotlib may
    # readjust them to keep the aspect ratio - confirm this is intended.
    plt.axis("equal")
    plt.title(title)

    plt.tight_layout()

    plt.grid(False)
    plt.show()

def plot_sim(x_histories, y_histories, h_histories, x_range, y_range, radius, title=""):
    """Animate robot trajectories frame by frame on the current figure.

    Args:
        x_histories: per-robot sequences of x positions.
        y_histories: per-robot sequences of y positions.
        h_histories: per-robot sequences of headings (radians).
        x_range: sequence whose index 1 is used as the upper x limit.
        y_range: sequence whose index 1 is used as the upper y limit.
        radius: robot radius forwarded to ``plot_roomba``.
        title: title drawn on every frame.

    Trajectories may have different lengths; a robot whose trajectory has
    ended stays at its last pose. After the last frame the function waits for
    Enter on stdin and then closes the figure.
    """
    # Convert each trajectory on its own: building one array from ragged
    # trajectories raises ValueError on recent NumPy releases.
    x_histories = [np.asarray(x) for x in x_histories]
    y_histories = [np.asarray(y) for y in y_histories]
    h_histories = [np.asarray(h) for h in h_histories]

    colors = plt.cm.Set1(np.linspace(0, 1, len(x_histories)))

    # default=0 keeps an empty input from crashing; it simply shows no frames.
    longest_traj = max((len(x) for x in x_histories), default=0)

    for i in range(longest_traj):
        plt.clf()
        for x_history, y_history, h_history, color in zip(x_histories, y_histories, h_histories, colors):
            if len(x_history) == 0:
                continue

            # Frame i shows pose i (clamped to the last pose once the
            # trajectory is over). Indexing with i-1 would draw the final
            # pose on frame 0 and never reach the last pose of the longest
            # trajectory.
            idx = i if i < len(x_history) else len(x_history) - 1

            plt.plot(
                x_history[:idx + 1],
                y_history[:idx + 1],
                c=color,
                marker=".",
                alpha=0.5,
                label="vehicle trajectory",
            )

            plot_roomba(x_history[idx], y_history[idx], h_history[idx], color, radius)

        plt.ylim(0, y_range[1])
        plt.xlim(0, x_range[1])
        plt.axis("equal")

        plt.tight_layout()

        plt.grid(False)
        plt.title(title)

        plt.draw()
        plt.pause(0.2)
    print("press enter to continue")
    input()
    plt.close()
    
def plot_roomba(x, y, yaw, color, radius=.7):
    """Draw a robot on the current axes as a circle outline plus a heading arrow.

    Args:
        x: x coordinate of the robot centre.
        y: y coordinate of the robot centre.
        yaw: heading angle in radians (passed to ``np.cos`` / ``np.sin``).
        color: colour of the circle outline.
        radius: circle radius; also the length of the heading offset.
    """
    axes = plt.gcf().gca()

    # Unfilled outline for the robot body.
    axes.add_patch(plt.Circle((x, y), radius, color=color, fill=False))

    # Direction marker from the centre along the heading; always drawn in red.
    tip_dx, tip_dy = radius * np.cos(yaw), radius * np.sin(yaw)
    axes.arrow(x, y, tip_dx, tip_dy, head_width=0.1, head_length=0.05, fc='r', ec='r')

def generate_prob_from_db(N, lib, cp_dist=-.5, sigma=.8, num_robots=4):
    """Draw a random problem from the solution library and build a state initial guess.

    Args:
        N: number of control intervals; the state guess has ``N + 1`` columns.
        lib: library object exposing ``key_to_idx`` (a mapping whose keys are
            indexable, four entries per robot: start x, start y, goal x,
            goal y) and ``get_matching_solution(starts, goals)``.
        cp_dist: forwarded to ``smooth_path``.
        sigma: forwarded to ``smooth_path``.
        num_robots: number of robots the sampled problem must contain.

    Returns:
        ``(start_nodes, goal_nodes, initial_guess_state, initial_guess_T)``
        where the start/goal nodes are lists of ``[x, y]`` integer pairs,
        ``initial_guess_state`` has shape ``(3 * num_robots, N + 1)`` with
        rows x, y, heading per robot, and ``initial_guess_T`` is twice the
        length of the longest path in the library solution.

    Raises:
        ValueError: if the library holds no problem with ``num_robots`` robots.
    """
    # Imported here because the module only imports `random` under the
    # __main__ guard, which is not executed when this file is imported.
    import random

    entries = list(lib.key_to_idx.items())

    # Fail fast instead of sampling forever when nothing can match.
    matching = [key for key, _ in entries if len(key) // 4 == num_robots]
    if not matching:
        raise ValueError(
            f"library contains no problem with {num_robots} robots"
        )

    # Rejection-sample a key with the requested number of robots.
    while True:
        key, _ = random.choice(entries)
        if len(key) // 4 == num_robots:
            break

    # Each robot occupies four consecutive key entries: start x, y, goal x, y.
    start_nodes = []
    goal_nodes = []
    for i in range(0, len(key), 4):
        start_nodes.append([int(key[i]), int(key[i+1])])
        goal_nodes.append([int(key[i+2]), int(key[i+3])])

    sol = lib.get_matching_solution(start_nodes, goal_nodes)

    # turn this solution into an initial guess
    initial_guess_state = np.zeros((num_robots*3, N+1))

    # the initial guess for time is twice the length of the longest path in the solution
    initial_guess_T = 2*max([len(sol[i]) for i in range(num_robots)])

    for i in range(num_robots):
        # Keep the points up to the first one whose x is -1
        # (presumably a padding sentinel - verify against the library format).
        points = []
        for point in np.array(sol[i]):
            if point[0] == -1:
                break
            points.append(point)

        points = np.array(points)
        smoothed_curve, _ = smooth_path(points, N+1, cp_dist, sigma)

        initial_guess_state[i*3, :] = smoothed_curve[:, 0]        # x
        initial_guess_state[i*3 + 1, :] = smoothed_curve[:, 1]    # y

        # Repeat the last heading so the row has N + 1 entries
        # (assumes calculate_headings returns one heading per segment).
        headings = calculate_headings(smoothed_curve)
        headings.append(headings[-1])

        initial_guess_state[i*3 + 2, :] = headings

    return start_nodes, goal_nodes, initial_guess_state, initial_guess_T


def calculate_initial_guess(state, h, N, num_robots, initial_guess_T):
    """Scale a grid-space state guess to world coordinates and derive a control guess.

    Args:
        state: 2-D array-like with three rows per robot (x, y, heading) and
            ``N + 1`` columns; x and y are in grid cells.
        h: grid resolution; cell ``c`` maps to ``c * h + 0.5 * h``.
        N: number of control intervals.
        num_robots: number of robots stored in ``state``.
        initial_guess_T: guessed total duration; the time step is
            ``initial_guess_T / N``.

    Returns:
        ``(initial_guess_state, initial_guess_control)``. The state is a
        scaled copy of ``state`` (headings unchanged; the input is not
        modified). The control array has shape ``(2 * num_robots, N)`` with
        rows velocity, omega per robot.
    """
    # Work on a local copy rather than relying on a module-level global
    # named `initial_guess_state` being defined by the caller.
    initial_guess_state = np.array(state, dtype=float)

    # Map the x and y rows from grid cells to world coordinates.
    x_rows = slice(0, 3 * num_robots, 3)
    y_rows = slice(1, 3 * num_robots, 3)
    initial_guess_state[x_rows] = initial_guess_state[x_rows] * h + .5 * h
    initial_guess_state[y_rows] = initial_guess_state[y_rows] * h + .5 * h

    # calculate initial guess for velocities and omegas
    # velocity is the change in position divided by the time step
    initial_guess_control = np.zeros((num_robots*2, N))
    dt = initial_guess_T / N
    for i in range(num_robots):
        x = initial_guess_state[i*3, :]
        y = initial_guess_state[i*3 + 1, :]

        # Euclidean distance between consecutive samples.
        step_lengths = np.hypot(np.diff(x), np.diff(y))
        initial_guess_control[i*2, :] = step_lengths / dt

        # omega is the difference between consecutive headings
        # NOTE(review): unlike velocity this is not divided by dt and is not
        # angle-wrapped - confirm that is what the solver expects.
        headings = initial_guess_state[i*3 + 2, :]
        initial_guess_control[i*2 + 1, :] = np.diff(headings)

    return initial_guess_state, initial_guess_control


if __name__ == "__main__":
    # Script entry point: load the settings, draw a problem from the solution
    # library, solve it with TrajOptResolver, re-solve it seeded with the
    # first solution, and plot the results.
    #
    # The statements stay at module scope so every name bound here remains a
    # module-level global.

    import os
    import random
    import time

    import yaml

    # -------------------- SET VARIABLES -------------------- #

    # load the yaml file
    # NOTE(review): this is a local config file; yaml.safe_load would be the
    # more conservative choice if it only holds plain data - confirm.
    with open("guided_mrmp/tests/initial_guesses.yaml") as file:
        settings = yaml.load(file, Loader=yaml.FullLoader)

    seed = 112358
    print(f"***Setting Python Seed {seed}***")
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    random.seed(seed)

    # define obstacles
    circle_obs = np.array(settings['environment']['circle_obs'])
    rectangle_obs = np.array(settings['environment']['rectangle_obs'])

    # weights for the cost function
    dist_robots_weight = settings['cost_weights']['dist_robots_weight']
    dist_obstacles_weight = settings['cost_weights']['dist_obstacles_weight']
    control_costs_weight = settings['cost_weights']['control_costs_weight']
    time_weight = settings['cost_weights']['time_weight']
    goal_weight = settings['cost_weights']['goal_weight']

    # other params
    rob_radius = settings['robot_radius']
    N = settings['N']

    from guided_mrmp.utils import Library
    lib_name = settings['library']['name']
    lib = Library("guided_mrmp/database/"+lib_name+"_library")
    lib.read_library_from_file()

    cp_dist = float(settings['control_point_distance'])
    num_trials = settings['num_trials']

    h = settings['grid_resolution']
    x_max = settings['library']['x_max']
    y_max = settings['library']['y_max']
    x_range = (0, x_max*h)
    y_range = (0, y_max*h)

    # -------------------- GENERATE PROBLEM AND INITIAL GUESS -------------------- #
    robot_starts_xy, robot_goals, initial_guess_state, initial_guess_T = generate_prob_from_db(N, lib, cp_dist)

    num_robots = len(robot_starts_xy)
    initial_guess_state, initial_guess_control = calculate_initial_guess(initial_guess_state, h, N, num_robots, initial_guess_T)

    # update the starts and goals to align with grid resolution
    robot_starts_xy = np.array(robot_starts_xy)
    robot_goals = np.array(robot_goals)
    robot_starts_xy = robot_starts_xy*h + .5*h
    robot_goals = robot_goals*h + .5*h

    # Default: seed the solver with the library-derived guess.
    initial_guesses = {
        'X': initial_guess_state,
        'U': initial_guess_control,
        'T': initial_guess_T
        }

    initial_guess_type = settings['initial_guess_type']

    if initial_guess_type == 'line':
        # Straight line from start to goal for every robot.
        initial_guess = np.zeros((num_robots*3, N+1))
        for i in range(0, num_robots*3, 3):
            start = robot_starts_xy[i // 3]
            goal = robot_goals[i // 3]
            initial_guess[i, :] = np.linspace(start[0], goal[0], N+1)
            initial_guess[i+1, :] = np.linspace(start[1], goal[1], N+1)

            # make the heading initial guess the difference between consecutive points
            # (the last column keeps its initial value of 0)
            for j in range(N):
                dx = initial_guess[i, j+1] - initial_guess[i, j]
                dy = initial_guess[i+1, j+1] - initial_guess[i+1, j]
                initial_guess[i+2, j] = np.arctan2(dy, dx)

        initial_guesses = {
            'X': initial_guess,
            'T': settings['initial_guess']['T']
            }

    elif initial_guess_type == 'None':
        initial_guesses = None

    # set the starts to include the heading
    # NOTE(review): the heading always comes from the library-derived guess,
    # even when initial_guess_type is 'line' or 'None' - confirm intended.
    robot_starts = []
    for i in range(num_robots):
        robot_starts.append([robot_starts_xy[i][0], robot_starts_xy[i][1], initial_guess_state[i*3+2, 0]])

    # -------------------- SOLVE THE PROBLEM -------------------- #
    solver = TrajOptResolver(num_robots=num_robots,
                            robot_radius=rob_radius,
                            starts=robot_starts,
                            goals=robot_goals,
                            circle_obstacles=circle_obs,
                            rectangle_obstacles=rectangle_obs,
                            rob_dist_weight=dist_robots_weight,
                            obs_dist_weight=dist_obstacles_weight,
                            control_weight=control_costs_weight,
                            time_weight=time_weight,
                            goal_weight=goal_weight,
                            conflicts=None,
                            all_robots=None
                            )

    solver_options = {'ipopt.print_level': settings['solver_options']['print_level'],
                        'print_time': settings['solver_options']['print_time'],
                        'ipopt.warm_start_init_point': 'yes'}

    start = time.time()
    old_sol, sol, pos, vels, omegas, xs, ys, thetas, T = solver.solve(N, x_range, y_range, initial_guesses, solver_options)
    end = time.time()

    print(f"Time to solve (1st time)= {end-start}")

    # Check for failure before touching `sol`; calling sol.stats() on None
    # would raise instead of reporting the failure.
    if sol is None:
        print("failed to find a solution")

    else:
        print(sol.stats()["iter_count"])

        # Try solving the problem again with the solution as a seed
        initial_guess_X = np.zeros((num_robots*3, N+1))
        for i in range(num_robots):
            initial_guess_X[i*3, :] = xs[i, :]
            initial_guess_X[i*3+1, :] = ys[i, :]
            initial_guess_X[i*3+2, :] = thetas[i, :]

        initial_guess_U = np.zeros((num_robots*2, N))
        for i in range(num_robots):
            initial_guess_U[i*2, :] = vels[i, :]
            initial_guess_U[i*2+1, :] = omegas[i, :]

        initial_guesses = {
            'X': initial_guess_X,
            'U': initial_guess_U,
            'T': T
            }

        solver = TrajOptResolver(num_robots=num_robots,
                        robot_radius=rob_radius,
                        starts=robot_starts,
                        goals=robot_goals,
                        circle_obstacles=circle_obs,
                        rectangle_obstacles=rectangle_obs,
                        rob_dist_weight=dist_robots_weight,
                        obs_dist_weight=dist_obstacles_weight,
                        control_weight=control_costs_weight,
                        time_weight=time_weight,
                        goal_weight=goal_weight,
                        conflicts=None,
                        all_robots=None
                        )

        # Loose acceptance tolerances for the seeded re-solve.
        solver_options = {'ipopt.print_level': settings['solver_options']['print_level'],
                        'print_time': settings['solver_options']['print_time'],
                        'ipopt.acceptable_tol': 1000,
                        'ipopt.acceptable_iter': 10,
                        'ipopt.dual_inf_tol': 1000,
                        'ipopt.compl_inf_tol': 1000,}

        start = time.time()
        lam_g, sol, pos, vels, omegas, xs, ys, thetas, T = solver.solve(N, x_range, y_range, initial_guesses, solver_options, old_sol)
        end = time.time()

        print(f"Time to solve with old solution as a seed = {end-start}")

        # Same guard as for the first solve: stop cleanly if the re-solve failed.
        if sol is None:
            raise SystemExit("failed to find a solution with the old solution as a seed")

        print(sol.stats()["iter_count"])

        plot_paths_db(circle_obs, num_robots, robot_starts, robot_goals, None, initial_guess_state, x_range, y_range, rob_radius)
        plot_sim(xs, ys, thetas, x_range, y_range, rob_radius)

        # Plot the re-solved solution side by side with the seed it started from.
        # The figure is a 2x3 grid:
        #   top row    - re-solved solution: positions with headings, velocity, omega
        #   bottom row - seed of the second solve (initial_guess_X / initial_guess_U,
        #                i.e. the output of the first solve), titled "Initial Guess"
        fig, axs = plt.subplots(2, 3, figsize=(12, 18))

        # create a color map for the robots
        colors = plt.cm.Set1(np.linspace(0, 1, num_robots))

        for i, (x, y, heading, vel, omega, color) in enumerate(zip(xs, ys, thetas, vels, omegas, colors)):
            axs[0,0].plot(x, y, color=color)
            axs[0,0].scatter(x, y, color=color)

            for x0, y0, heading0 in zip(x, y, heading):
                dx = 0.1 * np.cos(heading0)
                dy = 0.1 * np.sin(heading0)
                axs[0,0].arrow(x0, y0, dx, dy, head_width=0.05, head_length=0.1, fc='green', ec=color)

            axs[0,1].plot(vel, label=f"Robot {i}", color=color)
            axs[0,2].plot(omega, label=f"Robot {i}", color=color)

        axs[0,0].set_title("Robot Positions")
        axs[0,0].set_xlabel("X")
        axs[0,0].set_ylabel("Y")
        axs[0,1].set_title("Velocity")
        axs[0,2].set_title("Omega")
        axs[0,1].set_xlabel("Time")
        axs[0,1].set_ylabel("Control")
        axs[0,1].legend()

        seed_rows = zip(initial_guess_X[::3], initial_guess_X[1::3], initial_guess_X[2::3],
                        initial_guess_U[::2], initial_guess_U[1::2], colors)
        for i, (x, y, heading, vel, omega, color) in enumerate(seed_rows):
            axs[1,0].plot(x, y, color=color)
            axs[1,0].scatter(x, y, color=color)
            for x0, y0, heading0 in zip(x, y, heading):
                dx = 0.1 * np.cos(heading0)
                dy = 0.1 * np.sin(heading0)
                axs[1,0].arrow(x0, y0, dx, dy, head_width=0.05, head_length=0.1, fc=color, ec=color)

            axs[1,1].plot(vel, label=f"Robot {i}", color=color)
            axs[1,2].plot(omega, label=f"Robot {i}", color=color)

        axs[1,0].set_title("Initial Guess Robot Positions")
        axs[1,0].set_xlabel("X")
        axs[1,0].set_ylabel("Y")
        axs[1,1].set_title("Initial Guess Velocity")
        axs[1,2].set_title("Initial Guess Omega")
        axs[1,1].set_xlabel("Time")
        axs[1,1].set_ylabel("Control")
        axs[1,1].legend()

        plt.show()