import numpy as np
import matplotlib.pyplot as plt
import math
from scipy.ndimage import gaussian_filter1d

def bezier(t, points):
    """Evaluate a Bezier curve at parameter ``t``.

    Uses the Bernstein-polynomial form: with ``n = len(points) - 1``, control
    point ``k`` is weighted by ``C(n, k) * (1 - t)**(n - k) * t**k`` and the
    weighted points are accumulated onto a 2-D zero vector.

    Returns a NumPy array; an empty ``points`` yields ``[0.0, 0.0]``.
    """
    degree = len(points) - 1
    curve_point = np.zeros(2)
    for k in range(degree + 1):
        weight = math.comb(degree, k) * (1 - t) ** (degree - k) * t ** k
        curve_point = curve_point + weight * points[k]
    return curve_point

def smooth_path(points, N=100, alpha=0.25, sigma=1.0, samples_per_segment=20):
    """Turn a polyline of waypoints into a smooth path of ``N`` points.

    The waypoints are walked left to right:

    * four points whose two interior points are both turns (``is_bend``)
      become a cubic Bezier whose two inner control points are softened;
    * three points with one interior turn become a quadratic Bezier with one
      softened control point;
    * otherwise the next pair is joined by a straight line (or, if the pair
      is a repeated point, that point is repeated).

    The densely sampled result, with the first and last waypoints prepended
    and appended, is Gaussian-filtered in x and y and then downsampled to
    ``N`` points. The exact first and last waypoints are written back onto
    the ends afterwards.

    Parameters:
        points (array-like): sequence of (x, y) waypoints; lists are accepted.
        N (int): number of points in the returned path.
        alpha (float): offset passed to ``soften_control_point`` for every
            inner control point.
        sigma (float): standard deviation of the Gaussian filter.
        samples_per_segment (int): samples taken along each curve/line piece.

    Returns:
        (downsampled_points, control_points): an (N, 2) float array, and a
        NumPy array of the softened control points (shape (K, 2), or (0,)
        when no bends were found).

    Raises:
        ValueError: if ``points`` is empty.
    """
    # Convert once so list input works with the vector arithmetic below
    # (is_bend/soften/linear interpolation all subtract/scale points).
    points = np.asarray(points, dtype=float)
    if len(points) == 0:
        raise ValueError("points must contain at least one waypoint")

    ts = np.linspace(0, 1, samples_per_segment)
    smooth_points = []
    control_points = []
    i = 0
    while i < len(points) - 1:
        if i < len(points) - 3 and is_bend(points[i:i + 4]):
            # Double bend: cubic Bezier with both inner points softened.
            p0, p1, p2, p3 = points[i:i + 4]
            cp1 = soften_control_point(p1, alpha, p0, p2)
            cp2 = soften_control_point(p2, alpha, p1, p3)
            control_points.extend([cp1, cp2])
            smooth_points.extend(bezier(t, [p0, cp1, cp2, p3]) for t in ts)
            i += 3
        elif i < len(points) - 2 and is_bend(points[i:i + 3]):
            # Single bend: quadratic Bezier with the corner softened.
            p0, p1, p2 = points[i:i + 3]
            cp = soften_control_point(p1, alpha, p0, p2)
            control_points.append(cp)
            smooth_points.extend(bezier(t, [p0, cp, p2]) for t in ts)
            i += 2
        else:
            start, end = points[i], points[i + 1]
            if np.array_equal(start, end):
                # Repeated waypoint: keep the sample count per segment uniform.
                smooth_points.extend([start] * samples_per_segment)
            else:
                smooth_points.extend((1 - t) * start + t * end for t in ts)
            i += 1

    # Ensure start and end points are included in the dense sample.
    dense = np.array([points[0], *smooth_points, points[-1]])

    # Gaussian smoothing softens the joins between consecutive pieces.
    dense[:, 0] = gaussian_filter1d(dense[:, 0], sigma=sigma)
    dense[:, 1] = gaussian_filter1d(dense[:, 1], sigma=sigma)

    # Downsample to N points; filtering shifts the ends, so pin them back.
    indices = np.linspace(0, len(dense) - 1, N).astype(int)
    downsampled_points = dense[indices]
    downsampled_points[0], downsampled_points[-1] = points[0], points[-1]

    return downsampled_points, np.array(control_points)

def soften_control_point(middle_point, alpha, prev_point, next_point):
    """Offset a corner point along the bisector of its two adjacent edges.

    The unit vectors from ``middle_point`` toward ``prev_point`` and toward
    ``next_point`` are summed. For a real corner that sum points into the
    interior of the angle. The result is
    ``middle_point + alpha * unit_bisector``: positive ``alpha`` moves the
    point toward the inside of the corner, negative ``alpha`` moves it
    outward.

    If a neighbour coincides with ``middle_point``, or the three points are
    collinear with the middle point between them, the bisector has no
    direction. In that case a copy of ``middle_point`` is returned unchanged
    instead of NaNs.

    Returns:
        np.ndarray: float64 point.
    """
    # Below this, the bisector direction is numerically meaningless (straight line).
    eps = 1e-12

    middle = np.asarray(middle_point, dtype=np.float64)
    to_prev = np.asarray(prev_point, dtype=np.float64) - middle
    to_next = np.asarray(next_point, dtype=np.float64) - middle

    len_prev = np.linalg.norm(to_prev)
    len_next = np.linalg.norm(to_next)
    if len_prev == 0 or len_next == 0:
        return middle.copy()

    bisector = to_prev / len_prev + to_next / len_next
    len_bisector = np.linalg.norm(bisector)
    if len_bisector < eps:
        return middle.copy()

    return middle + alpha * (bisector / len_bisector)

def is_bend(segment):
    """Return True if every interior point of a 3- or 4-point segment is a turn.

    An interior point counts as a turn when its incoming and outgoing
    displacement vectors are not parallel (non-zero 2-D cross product). This
    holds for any change of direction, not only 90-degree ones. Straight
    continuations and 180-degree reversals are *not* bends. Segments of any
    other length return False.
    """
    if len(segment) not in (3, 4):
        return False
    pts = np.asarray(segment, dtype=float)
    steps = np.diff(pts, axis=0)  # consecutive displacement vectors
    # Explicit z-component of the 2-D cross product; np.cross on 2-element
    # vectors is deprecated as of NumPy 2.0.
    turns = steps[:-1, 0] * steps[1:, 1] - steps[:-1, 1] * steps[1:, 0]
    return bool(np.all(turns != 0))

def calculate_headings(path_points):
    """
    Calculate headings for each segment in the path, allowing for reverse movement.

    For each segment, the forward direction ``arctan2(dy, dx)`` and its
    reverse (forward rotated by pi) are compared with the previous heading.
    Whichever requires the smaller turn is kept, so a path that doubles back
    is treated as moving in reverse rather than spinning around. A
    zero-length segment repeats the previous heading. The first segment
    always uses its forward direction (0 if it has zero length, since
    ``arctan2(0, 0) == 0``).

    Parameters:
        path_points (np.ndarray): Array of (x, y) points representing the smoothed path.

    Returns:
        headings (list): ``len(path_points) - 1`` headings in radians, all in
        the same [-pi, pi] range that ``np.arctan2`` produces.
    """

    def _turn(a, b):
        # Absolute smallest angle between two headings, in [0, pi].
        return np.abs((a - b + np.pi) % (2 * np.pi) - np.pi)

    headings = []
    prev_heading = None

    for p1, p2 in zip(path_points[:-1], path_points[1:]):
        forward_heading = np.arctan2(p2[1] - p1[1], p2[0] - p1[0])

        if prev_heading is None:
            chosen_heading = forward_heading
        elif p2[0] == p1[0] and p2[1] == p1[1]:
            # No displacement, so no direction information; keep the last heading.
            chosen_heading = prev_heading
        else:
            # Reverse direction expressed in the same range as arctan2 so the
            # returned list never mixes [-pi, pi] and [0, 2pi) angles.
            if forward_heading > 0:
                reverse_heading = forward_heading - np.pi
            else:
                reverse_heading = forward_heading + np.pi
            if _turn(forward_heading, prev_heading) <= _turn(reverse_heading, prev_heading):
                chosen_heading = forward_heading
            else:
                chosen_heading = reverse_heading

        headings.append(chosen_heading)
        prev_heading = chosen_heading

    return headings

if __name__ == "__main__":
    # Demo: smooth a small waypoint path, print the per-segment headings and
    # plot the result. Other waypoint sets worth trying:
    #   [[0, 0], [0, 1], [1, 1], [2, 1], [2, 2], [2, 3], [2, 2], [2, 1], [2, 0]]  (doubles back)
    #   [[0, 0], [0, 1], [1, 1], [1, 0], [0, 0]]                                (closed loop)
    points = np.array([[3, 3], [3, 5], [1, 5]])

    # Generate smooth curve and get the softened Bezier control points.
    N = 40
    smooth_points, control_points = smooth_path(points, N, alpha=-.5, sigma=0.8)
    print(f"smooth_points = {smooth_points}")

    headings = calculate_headings(smooth_points)
    for i, heading in enumerate(headings):
        print(f"Segment {i}: Heading = {np.degrees(heading):.2f} degrees")

    plt.figure(figsize=(8, 8))
    plt.plot(smooth_points[:, 0], smooth_points[:, 1], 'b-', label="Bezier Smooth Path")

    # The input points are waypoints; the softened control points that
    # smooth_path collected are drawn separately (only if any bends were found).
    plt.scatter(points[:, 0], points[:, 1], color="purple", marker="x", s=100, label="Waypoints")
    if len(control_points) > 0:
        plt.scatter(control_points[:, 0], control_points[:, 1], color="orange", marker="o",
                    s=60, label="Bezier Control Points")

    # Mark every smoothed point and draw an arrow along the heading of the
    # segment that starts there (the last point has no outgoing segment).
    for i, (x, y) in enumerate(smooth_points):
        plt.plot(x, y, 'ro')
        if i < len(headings):
            dx = 0.1 * np.cos(headings[i])
            dy = 0.1 * np.sin(headings[i])
            plt.arrow(x, y, dx, dy, head_width=0.05, head_length=0.1, fc='green', ec='green')

    plt.xlabel("X")
    plt.ylabel("Y")
    plt.title("Smoothed Path")
    plt.legend()
    plt.grid(True)
    plt.show()