import numpy as np
import matplotlib.pyplot as plt
from guided_mrmp.utils import Library
import sys

from guided_mrmp.conflict_resolvers import TrajOptMultiRobot

def bezier_curve(t, control_points):
    """Evaluate a quadratic Bézier curve at parameter ``t``.

    Uses the Bernstein form ``B(t) = (1-t)^2 P0 + 2(1-t)t P1 + t^2 P2``.

    Args:
        t: Curve parameter (``0`` gives the first control point, ``1`` the
            last). May be a scalar or anything that broadcasts against the
            control points.
        control_points: Exactly three control points ``(P0, P1, P2)``;
            scalars or NumPy arrays.

    Returns:
        The point on the curve at ``t``.
    """
    start, middle, end = control_points
    s = 1 - t  # complementary parameter
    return s**2 * start + 2 * s * t * middle + t**2 * end

def smooth_path(points, control_point_distance, num_samples=20):
    """Round the corners of a polyline with quadratic Bézier arcs.

    For every interior vertex ``P1`` with neighbours ``P0`` and ``P2``, a
    quadratic Bézier arc is sampled. Its control points are
    ``P1 - u_in * d``, ``P1`` and ``P1 + u_out * d``, where ``u_in`` and
    ``u_out`` are the unit directions of the incoming and outgoing segments
    and ``d`` is ``control_point_distance``. The first and last input points
    are kept as the endpoints of the returned path.

    Consecutive duplicate points (zero-length segments, e.g. a robot waiting
    in place) are skipped when building the arcs. Without this, the tangent
    normalisation would divide by zero and fill the curve with NaNs.

    Args:
        points: Array-like of points, shape ``(n, dim)``.
        control_point_distance: Distance from each interior vertex to the
            arc's end control points along the adjacent segments.
        num_samples: Number of parameter samples per corner arc. The first
            sample (``t = 0``) of each arc is not emitted. Defaults to 20.

    Returns:
        ``np.ndarray`` holding the smoothed path. Empty input yields an empty
        array. A single point yields that point twice (start and end).
    """
    pts = np.asarray(points)
    if len(pts) == 0:
        return pts.copy()

    # Remove zero-length segments so the tangent normalisation below never
    # divides by zero.
    seg_lengths = np.linalg.norm(np.diff(pts, axis=0), axis=1)
    corners = pts[np.concatenate(([True], seg_lengths > 0))]

    # The sampling parameters are identical for every corner.
    t_values = np.linspace(0, 1, num_samples)

    smoothed_curve = [pts[0]]
    for P0, P1, P2 in zip(corners[:-2], corners[1:-1], corners[2:]):
        # Unit directions of the incoming and outgoing segments.
        tangent_start = (P1 - P0) / np.linalg.norm(P1 - P0)
        tangent_end = (P2 - P1) / np.linalg.norm(P2 - P1)

        control_points = [
            P1 - tangent_start * control_point_distance,
            P1,
            P1 + tangent_end * control_point_distance,
        ]
        curve_points = np.array([bezier_curve(t, control_points) for t in t_values])

        # The t = 0 sample is skipped (kept from the original behaviour).
        smoothed_curve.extend(curve_points[1:])

    smoothed_curve.append(pts[-1])
    return np.array(smoothed_curve)


if __name__ == "__main__":
    # Demo: look up a library solution for a small 3-robot problem, plot each
    # robot's path next to its Bézier-smoothed version, then run the
    # multi-robot trajectory optimizer on the same starts/goals.

    # This demo has no obstacles.
    circle_obs = np.array([])
    rectangle_obs = np.array([])

    lib = Library("guided_mrmp/database/2x3_library")
    lib.read_library_from_file()

    robot_starts = [[0, 0], [0, 2], [1, 2]]
    robot_goals = [[0, 1], [1, 2], [0, 2]]

    # Named separately from the optimizer result `sol` further down, which
    # previously shadowed this variable.
    lib_sol = lib.get_matching_solution(robot_starts, robot_goals)
    print(lib_sol)

    # Distance of the Bézier control points from each corner.
    control_point_distance = 0.3

    for points in lib_sol:
        points = np.array(points)

        # Drop rows equal to [-1, -1] (presumably padding in the library
        # solution -- confirm against Library).
        points = points[(points != [-1, -1]).any(axis=1)]
        print(f"points = {points}")

        smoothed_curve = smooth_path(points, control_point_distance)
        print(f"smoothed_curve = {smoothed_curve}")

        # Plot the original points and the smoothed curve.
        plt.plot(points[:, 0], points[:, 1], 'bo-', label='original path')
        plt.plot(smoothed_curve[:, 0], smoothed_curve[:, 1], 'r-', label='curved path')
        plt.xlabel('X')
        plt.ylabel('Y')
        plt.legend()
        plt.grid(True)
        plt.axis('equal')
        plt.show()

    # Weights for the cost function.
    dist_robots_weight = 10
    dist_obstacles_weight = 10
    control_costs_weight = 1.0
    time_weight = 5.0

    # Other parameters.
    num_robots = len(robot_starts)
    rob_radius = 0.25
    N = 20

    # Straight-line initial guess from each start to its goal over N + 1 steps.
    # Previously this was commented out, so `initial_guess` was undefined
    # (NameError) when passed to solve()/plot_paths() below.
    # NOTE(review): the layout (3 rows per robot: x, y, then a third state left
    # at 0) follows the earlier commented-out code; confirm it matches what
    # TrajOptMultiRobot.solve expects.
    initial_guess = np.zeros((num_robots * 3, N + 1))
    for r, (start, goal) in enumerate(zip(robot_starts, robot_goals)):
        initial_guess[3 * r, :] = np.linspace(start[0], goal[0], N + 1)
        initial_guess[3 * r + 1, :] = np.linspace(start[1], goal[1], N + 1)

    solver = TrajOptMultiRobot(num_robots=num_robots,
                               robot_radius=rob_radius,
                               starts=robot_starts,
                               goals=robot_goals,
                               circle_obstacles=circle_obs,
                               rectangle_obstacles=rectangle_obs,
                               rob_dist_weight=dist_robots_weight,
                               obs_dist_weight=dist_obstacles_weight,
                               control_weight=control_costs_weight,
                               time_weight=time_weight
                               )
    sol, pos = solver.solve(N, initial_guess)
    pos_vals = np.array(sol.value(pos))

    solver.plot_paths(pos_vals, initial_guess)