import numpy as np
from scipy.interpolate import interp1d


def compute_path_from_wp(start_xp, start_yp, step=0.1):
    """
    Densify a polyline of waypoints into a sampled path with headings.

    Each segment between consecutive waypoints is linearly interpolated with
    ``max(2, floor(segment_length / step))`` samples spanning the segment.
    The segment's start sample is dropped so that shared waypoints are not
    duplicated. As a result, the first waypoint itself is not part of the
    output, while every later waypoint is. Zero-length segments (repeated
    waypoints) are skipped.

    params:
        start_xp (array-like): 1D array of x-positions
        start_yp (array-like): 1D array of y-positions
        step (float): nominal interpolation step size; must be positive

    output:
        ndarray of shape (3, N) representing the path as x, y, heading.
        heading[i] is the direction of travel into point i: from the previous
        point, or from the first waypoint when i == 0. N is 0 when fewer than
        two distinct waypoints are given.

    raises:
        ValueError: if step is not positive.
    """
    if step <= 0:
        raise ValueError("step must be positive")

    xs = np.asarray(start_xp, dtype=float)
    ys = np.asarray(start_yp, dtype=float)

    seg_x = []
    seg_y = []
    for idx in range(len(xs) - 1):
        # distance between consecutive waypoints
        section_len = np.hypot(xs[idx + 1] - xs[idx], ys[idx + 1] - ys[idx])
        if section_len == 0.0:
            # repeated waypoint: no new points and no direction to derive
            continue

        # At least the two segment endpoints, so that a segment shorter than
        # 2 * step still contributes its end waypoint after the [1:] below.
        n_samples = max(2, int(np.floor(section_len / step)))
        interp_range = np.linspace(0, 1, n_samples)

        # interpolate between waypoints
        fx = interp1d(np.linspace(0, 1, 2), xs[idx : idx + 2], kind=1)
        fy = interp1d(np.linspace(0, 1, 2), ys[idx : idx + 2], kind=1)

        # [1:] drops the segment start, which equals the previous segment's end
        seg_x.append(fx(interp_range)[1:])
        seg_y.append(fy(interp_range)[1:])

    if not seg_x:
        return np.empty((3, 0))

    final_xp = np.concatenate(seg_x)
    final_yp = np.concatenate(seg_y)

    # Backward differences. Prepending the first waypoint gives the first
    # sample the heading of the first segment instead of arctan2(0, 0) == 0.
    dx = np.diff(final_xp, prepend=xs[0])
    dy = np.diff(final_yp, prepend=ys[0])
    theta = np.arctan2(dy, dx)
    return np.vstack((final_xp, final_yp, theta))

def get_nn_idx(state, path):
    """
    Find the index of the path point nearest to the current state.

    Distance is Euclidean in the x-y plane; the state's heading is ignored.

    Args:
        state (array-like): Current state [x, y, theta]; only x and y are used.
        path (ndarray): Path points with x in row 0 and y in row 1.

    Returns:
        int: Column index of the nearest path point (the first one on ties).
    """
    # asarray so that plain lists/tuples work, as the docstring promises
    x, y = np.asarray(state, dtype=float)[:2]
    distances = np.hypot(path[0, :] - x, path[1, :] - y)
    return int(np.argmin(distances))

def fix_angle_reference(angle_ref, angle_init):
    """
    Unwrap a sequence of angles so consecutive values never jump by more than pi.

    The angles are expressed relative to ``angle_init`` and passed through
    ``np.unwrap``. That call adds multiples of 2*pi wherever consecutive
    values differ by more than pi. The result is then shifted back by
    ``angle_init``. Because unwrapping only adjusts later elements, the first
    element comes back unchanged (up to floating-point rounding).

    Args:
        angle_ref (array-like): Reference angles in radians.
        angle_init (float): Angle the sequence is referenced to.

    Returns:
        ndarray: Unwrapped reference angles, same length as ``angle_ref``.
    """
    # asarray so list input works (list - float would raise TypeError)
    relative = np.asarray(angle_ref, dtype=float) - angle_init
    return angle_init + np.unwrap(relative)

def get_ref_trajectory(state, path, target_v, T, DT):
    """
    Generates a reference trajectory for the Roomba in the ego frame.

    Starting from the path point nearest to ``state``, the path is resampled
    at arc-length increments of ``target_v * DT`` for ``K = int(T / DT)``
    steps. Samples beyond the end of the path are clamped to the last path
    point (``np.interp`` behaviour).

    Args:
        state (array-like): Current state [x, y, theta] in the global frame
        path (ndarray): Path points [x, y, theta] in the global frame
            (rows 0, 1, 2 respectively)
        target_v (float): Desired speed
        T (float): Control horizon duration
        DT (float): Control horizon time-step

    Returns:
        ndarray: Reference trajectory of shape (3, K), rows [x_k, y_k, theta_k]
        in the ego frame. theta_k is the path heading relative to the current
        heading. It is normalized so the first value lies in [-pi, pi), then
        unwrapped along the horizon so it has no 2*pi jumps.
    """
    state = np.asarray(state, dtype=float)

    # NOTE(review): int() truncates, e.g. T=0.7, DT=0.1 gives K=6. Kept as-is
    # because the controller presumably sizes its horizon the same way; confirm.
    K = int(T / DT)

    xref = np.zeros((3, K))  # Reference trajectory for [x, y, theta]

    # find the nearest path point to the current state
    ind = get_nn_idx(state, path)

    # Cumulative arc length along the path. It is non-decreasing by
    # construction, so no clipping is needed.
    cdist = np.append(0.0, np.cumsum(np.hypot(np.diff(path[0, :]), np.diff(path[1, :]))))

    # Arc-length positions the robot should reach at steps 1..K
    interp_points = np.arange(1, K + 1) * DT * target_v + cdist[ind]

    # Interpolate the path at those positions (global frame). The heading is
    # unwrapped first, so that interpolating across the +-pi seam does not
    # sweep through 0.
    x_glob = np.interp(interp_points, cdist, path[0, :])
    y_glob = np.interp(interp_points, cdist, path[1, :])
    theta_glob = np.interp(interp_points, cdist, np.unwrap(path[2, :]))

    # Transform to ego frame: translate by the robot position, rotate by -theta
    dx = x_glob - state[0]
    dy = y_glob - state[1]
    cos_t = np.cos(-state[2])
    sin_t = np.sin(-state[2])
    xref[0, :] = dx * cos_t - dy * sin_t  # X
    xref[1, :] = dy * cos_t + dx * sin_t  # Y
    # Per-step interpolated heading (not the nearest point's heading repeated),
    # so the reference turns with the path over the horizon.
    xref[2, :] = theta_glob - state[2]  # Theta

    # Normalize the angles, then make them continuous along the horizon
    xref[2, :] = (xref[2, :] + np.pi) % (2.0 * np.pi) - np.pi
    if K > 0:
        xref[2, :] = fix_angle_reference(xref[2, :], xref[2, 0])

    return xref