import numpy as np
from scipy.interpolate import interp1d


def compute_path_from_wp(start_xp, start_yp, step=0.1):
    """
    Densify a polyline of waypoints into a sampled path with headings.

    Each segment between consecutive waypoints is linearly interpolated with
    ``floor(segment_length / step)`` samples, but never fewer than 2, so the
    segment's end waypoint is always kept. The first sample of every segment
    is dropped to avoid duplicating the waypoint shared with the previous
    segment. As a consequence, the very first waypoint is not itself part of
    the output. Zero-length segments (repeated waypoints) are skipped.

    params:
        start_xp (array-like): 1D array of x-positions
        start_yp (array-like): 1D array of y-positions
        step (float): target spacing between interpolated points; the actual
            spacing inside a segment is ``length / (n_samples - 1)``, which is
            at least ``step`` for segments longer than ``2 * step``

    output:
        ndarray of shape (3, N) representing the path as x, y, heading.
        heading[i] is the direction of travel arriving at point i. For the
        first point, it is the direction from the first waypoint. N is 0 when
        fewer than two distinct waypoints are given.
    """
    xs = np.asarray(start_xp, dtype=float)
    ys = np.asarray(start_yp, dtype=float)

    seg_x = []
    seg_y = []
    for x0, x1, y0, y1 in zip(xs[:-1], xs[1:], ys[:-1], ys[1:]):
        section_len = np.hypot(x1 - x0, y1 - y0)
        if section_len == 0.0:
            # repeated waypoint: nothing to interpolate, no defined heading
            continue
        # at least 2 samples so that dropping the first one still keeps the
        # segment's end waypoint (short segments used to vanish entirely)
        n_samples = max(int(np.floor(section_len / step)), 2)
        t = np.linspace(0.0, 1.0, n_samples)[1:]
        seg_x.append(x0 + (x1 - x0) * t)
        seg_y.append(y0 + (y1 - y0) * t)

    if not seg_x:
        return np.empty((3, 0))

    # collect once instead of np.append per segment (avoids quadratic copying)
    final_xp = np.concatenate(seg_x)
    final_yp = np.concatenate(seg_y)

    # heading of the step arriving at each point; the first point uses the
    # first waypoint as its predecessor instead of a fake zero step
    dx = np.diff(final_xp, prepend=xs[0])
    dy = np.diff(final_yp, prepend=ys[0])
    theta = np.arctan2(dy, dx)
    return np.vstack((final_xp, final_yp, theta))

def get_nn_idx(state, path, visited=None):
    """
    Find the index of the path point nearest to the current state.

    "Nearest" means the smallest Euclidean distance in the x-y plane among
    path points that have not been visited yet. A path point counts as
    visited when it lies within 1e-3 of any entry of ``visited``. If every
    path point has been visited, the nearest point overall is returned.
    Previously an arbitrary index 0 came out of an all-inf argmin.

    Args:
        state (array-like): Current state [x, y, theta]; only x and y are used
        path (ndarray): Path of shape (>=2, N); row 0 holds x, row 1 holds y
        visited (array-like, optional): Visited points [[x1, y1], [x2, y2], ...]

    Returns:
        int: Column index of the nearest path point
    """
    xy = np.asarray(path)[:2]
    position = np.asarray(state, dtype=float)[:2].reshape(2, 1)

    # Euclidean distance between the current position and every path point
    distances = np.linalg.norm(xy - position, axis=0)

    # mark path points that coincide (within tolerance) with a visited point
    is_visited = np.zeros(distances.shape, dtype=bool)
    for point in ([] if visited is None else visited):
        offset = xy - np.asarray(point, dtype=float).reshape(2, 1)
        is_visited |= np.linalg.norm(offset, axis=0) < 1e-3

    if is_visited.all():
        # nothing left to choose from: fall back to the plain nearest point
        return int(np.argmin(distances))
    return int(np.argmin(np.where(is_visited, np.inf, distances)))

def fix_angle_reference(angle_ref, angle_init):
    """
    Make a sequence of angles continuous by removing wrap-around jumps.

    The angles are first expressed relative to ``angle_init`` and then passed
    through ``np.unwrap``. That call shifts every step whose magnitude exceeds
    pi by the multiple of 2*pi that brings it back into range. Finally the
    angles are shifted back by ``angle_init``.

    Args:
        angle_ref (ndarray): Reference angles [rad]
        angle_init (float): Offset applied before and removed after unwrapping [rad]

    Returns:
        ndarray: Unwrapped reference angles, same shape as ``angle_ref``
    """
    relative = np.unwrap(angle_ref - angle_init)
    return relative + angle_init

def get_ref_trajectory(state, path, target_v, T, DT, path_visited_points=None):
    """
    Generates a reference trajectory for the Roomba, expressed in the ego frame.

    The nearest not-yet-visited path point (from ``get_nn_idx``) is the
    starting point, and it is recorded in ``path_visited_points``. The
    reference for step k (k = 1..K, K = T / DT) is the point reached after
    travelling ``k * DT * target_v`` along the path from there. Positions and
    headings are linearly interpolated over cumulative path length and clamped
    at the end of the path.

    Args:
        state (array-like): Current state [x, y, theta] in the global frame
        path (ndarray): Path of shape (3, N), rows [x, y, theta], global frame
        target_v (float): Desired speed
        T (float): Control horizon duration
        DT (float): Control horizon time-step
        path_visited_points (list, optional): Visited path points
            [[x, y], [x, y], ...]; appended to in place. A new list is
            created when omitted.

    Returns:
        tuple:
            ndarray of shape (3, K): reference [x_k, y_k, theta_k] in the ego
                frame; theta_k is relative to the current heading and
                continuous (unwrapped) along the horizon
            list: ``path_visited_points`` with the newly visited point appended
    """
    if path_visited_points is None:
        # fresh list per call: a shared `[]` default would accumulate visited
        # points across unrelated callers
        path_visited_points = []

    # small epsilon so e.g. T=0.3, DT=0.1 gives 3 steps (0.3 / 0.1 < 3.0 in floats)
    K = int(np.floor(T / DT + 1e-9))

    xref = np.zeros((3, K))  # Reference trajectory for [x, y, theta]

    # find the nearest path point to the current state and mark it visited
    ind = get_nn_idx(state, path, path_visited_points)
    path_visited_points.append([path[0, ind], path[1, ind]])

    # cumulative distance along the path (arc length at each path point)
    cdist = np.append([0.0], np.cumsum(np.hypot(np.diff(path[0, :]), np.diff(path[1, :]))))

    # arc length where the robot should be at each time step; np.interp clamps
    # values past the end of the path to the last path point
    start_dist = cdist[ind]
    interp_points = np.arange(1, K + 1) * DT * target_v + start_dist

    # interpolate along the path; headings are unwrapped first so that
    # interpolating across the +-pi seam does not produce spurious angles
    xref[0, :] = np.interp(interp_points, cdist, path[0, :])
    xref[1, :] = np.interp(interp_points, cdist, path[1, :])
    xref[2, :] = np.interp(interp_points, cdist, np.unwrap(path[2, :]))

    # Transform to ego frame: translate by the robot position, rotate by -theta
    dx = xref[0, :] - state[0]
    dy = xref[1, :] - state[1]
    xref[0, :] = dx * np.cos(-state[2]) - dy * np.sin(-state[2])  # X
    xref[1, :] = dy * np.cos(-state[2]) + dx * np.sin(-state[2])  # Y
    # keep the per-step interpolated heading (previously overwritten with a
    # constant taken from the nearest point only)
    xref[2, :] = xref[2, :] - state[2]  # Theta

    # wrap to [-pi, pi), then remove jumps so the heading is continuous
    xref[2, :] = (xref[2, :] + np.pi) % (2.0 * np.pi) - np.pi
    if K > 0:
        xref[2, :] = fix_angle_reference(xref[2, :], xref[2, 0])

    return xref, path_visited_points