import cvxpy as cp
import numpy as np
import time

def _unique_cell_constraints(grid_indices, robot_locations, M_ind, M_cts):
    """
        Big-M constraints that put every pair of robots in different grid cells.

        For each pair (i, j) at least one axis must be "separated" (xsep / ysep).
        On a separated axis the two grid indices differ by at least 1 and keep the
        same order as the robots' continuous coordinates on that axis. The per-axis
        order indicator is 0 when robot i's coordinate is >= robot j's and 1 when it
        is <= (either value is allowed when the coordinates are equal).

        inputs:
            - grid_indices (cp.Variable): (num_robots, 2) integer grid indices
            - robot_locations (np.ndarray): (num_robots, 2) robot positions
            - M_ind (float): big-M for differences of grid indices
            - M_cts (float): big-M for differences of robot coordinates
        outputs:
            - constraints (list): cvxpy constraints
    """
    constraints = []
    num_robots = len(robot_locations)
    for i in range(num_robots):
        for j in range(i + 1, num_robots):
            # At least one axis must separate the two robots' cells
            xsep = cp.Variable(boolean=True)
            ysep = cp.Variable(boolean=True)
            constraints.append(xsep + ysep >= 1)

            for axis, sep in ((0, xsep), (1, ysep)):
                order = cp.Variable(boolean=True)
                gap = float(robot_locations[j, axis] - robot_locations[i, axis])
                # order = 0: robot i's coordinate >= robot j's -> index i >= index j + 1 (if sep)
                constraints.append(M_cts * order >= gap)
                constraints.append(grid_indices[i, axis] - grid_indices[j, axis] + M_ind * order + M_ind * (1 - sep) >= 1)
                # order = 1: robot i's coordinate <= robot j's -> index j >= index i + 1 (if sep)
                constraints.append(M_cts * (1 - order) >= -gap)
                constraints.append(grid_indices[j, axis] - grid_indices[i, axis] + M_ind * (1 - order) + M_ind * (1 - sep) >= 1)
    return constraints


def _obstacle_overlap_constraints(obstacle_cells, obstacles, bottom_left, cell_size, grid_size, M):
    """
        Big-M constraints forcing obstacle_cells[i, x * grid_size + y] to 1 whenever the
        axis-aligned bounding square of circular obstacle i overlaps grid cell (x, y).

        A cell may only be marked free (0) if the obstacle's bounding square lies entirely
        to the left of, right of, below, or above the cell (touching counts as free).
        Because the bounding square contains the circle, every cell the circle really
        overlaps is marked; cells near the square's corners may be marked conservatively.

        inputs:
            - obstacle_cells (cp.Variable): (num_obst, grid_size * grid_size) booleans
            - obstacles (np.ndarray): (num_obst, 3) rows of [x, y, r]
            - bottom_left (cp.Variable): bottom-left corner of the grid
            - cell_size (float): cell width
            - grid_size (int): cells per side
            - M (float): big-M bounding |obstacle extent - cell edge| over feasible grids
        outputs:
            - constraints (list): cvxpy constraints
    """
    constraints = []
    for i, (cx, cy, r) in enumerate(obstacles):
        cx, cy, r = float(cx), float(cy), float(r)
        for x_idx in range(grid_size):
            for y_idx in range(grid_size):
                cell_lo = bottom_left + np.array([x_idx, y_idx]) * cell_size  # bottom-left of the cell
                cell_hi = cell_lo + cell_size                                  # top-right of the cell
                occupied = obstacle_cells[i, x_idx * grid_size + y_idx]

                # One indicator per direction in which the obstacle may lie clear of the cell
                left = cp.Variable(boolean=True)
                right = cp.Variable(boolean=True)
                below = cp.Variable(boolean=True)
                above = cp.Variable(boolean=True)
                constraints += [
                    # Unless the cell is marked occupied, some direction must separate them
                    left + right + below + above >= 1 - occupied,
                    cell_lo[0] + M * (1 - left) >= cx + r,
                    cell_hi[0] - M * (1 - right) <= cx - r,
                    cell_lo[1] + M * (1 - below) >= cy + r,
                    cell_hi[1] - M * (1 - above) <= cy - r,
                ]
    return constraints


def place_grid(robot_locations, cell_size, grid_size=5, subgoals=None, obstacles=None):
    """
        Place an axis-aligned grid so that every robot gets its own cell and the cell
        centers are as close as possible to the robots (mixed-integer program, SCIP).

        inputs:
            - robot_locations (list): locations of robots involved in conflict [[x,y], [x,y], ...]
            - cell_size (float): the width of each grid cell in continuous space
            - grid_size (int): width (and height) of the grid in cells
            - subgoals (list, optional): points [[x,y], ...] that must also lie inside the grid
            - obstacles (list, optional): locations of circular obstacles [[x,y,r], [x,y,r], ...];
              (conservative) overlap between obstacles and cells is penalized in the objective
        outputs:
            - origin (np.ndarray): bottom-left corner of the grid in continuous space
            - cell_centers (np.ndarray): centers of grid cells for each robot (same order as robot_locations)
            Returns None (after printing a message) if the solver does not report "optimal".
    """
    start_time = time.time()

    robot_locations = np.asarray(robot_locations, dtype=float)
    # None / empty inputs become (0, 2) and (0, 3) arrays so no special cases are needed below
    subgoals = np.asarray([] if subgoals is None else subgoals, dtype=float).reshape(-1, 2)
    obstacles = np.asarray([] if obstacles is None else obstacles, dtype=float).reshape(-1, 3)
    num_robots = len(robot_locations)
    num_obst = len(obstacles)
    # `if obstacles:` would raise on a multi-element NumPy array, so test the count
    has_obstacles = num_obst > 0

    M_ind = 10 * grid_size  # Big M relative to grid indices
    M_cts = float(10 * max(np.ptp(robot_locations[:, 0]), np.ptp(robot_locations[:, 1])))  # Big M relative to robot locations

    # Decision variable: bottom-left corner of the grid in continuous space
    bottom_left = cp.Variable(2, name='origin')
    # Top-right corner for convenience (axis-aligned grid with a fixed cell size)
    top_right = bottom_left + grid_size * cell_size

    # Decision variable: integer grid indices (x, y) for each robot
    grid_indices = cp.Variable((num_robots, 2), integer=True, name='grid_indices')

    # Center of each robot's cell; the origin is reshaped to (1, 2) to broadcast over robots
    cell_centers = cp.reshape(bottom_left, (1, 2), order='C') + grid_indices * cell_size + cell_size / 2

    # Objective: minimize the sum of squared robot-to-cell-center distances
    cost = cp.sum_squares(robot_locations - cell_centers)
    if has_obstacles:
        # obstacle_cells[i, x * grid_size + y] == 1 when obstacle i may overlap cell (x, y).
        # Kept 2-D because older cvxpy releases only support expressions of up to 2 dimensions.
        obstacle_cells = cp.Variable((num_obst, grid_size * grid_size), boolean=True, name='obstacle_cells')
        alpha = 0.10  # weight of the obstacle-overlap penalty
        cost = (1 - alpha) * cost + alpha * cp.sum(obstacle_cells)

    constraints = [
        grid_indices >= 0,              # grid indices must be non-negative
        grid_indices <= grid_size - 1,  # ... and fit within the grid bounds
    ]

    # No two robots can share a cell
    constraints += _unique_cell_constraints(grid_indices, robot_locations, M_ind, M_cts)

    # All robots and subgoals must be within the grid bounds
    for point in np.vstack([robot_locations, subgoals]):
        constraints.append(bottom_left <= point)
        constraints.append(top_right >= point)

    if has_obstacles:
        # Every robot lies inside the grid, so each cell edge is within grid_size * cell_size
        # of the robots' coordinate range. Hence |obstacle extent - cell edge| is at most
        # ptp(coords) + extent; doubled for slack.
        extent = grid_size * cell_size
        centers, radii = obstacles[:, :2], obstacles[:, 2:]
        coords = np.concatenate([robot_locations.ravel(),
                                 (centers - radii).ravel(),
                                 (centers + radii).ravel()])
        M_obs = float(2 * (np.ptp(coords) + extent))
        constraints += _obstacle_overlap_constraints(obstacle_cells, obstacles, bottom_left,
                                                     cell_size, grid_size, M_obs)

    # Solve the optimization problem
    prob_init_start_time = time.time()
    prob = cp.Problem(cp.Minimize(cost), constraints)
    solve_start_time = time.time()
    prob.solve(solver=cp.SCIP, verbose=True)
    solve_end_time = time.time()

    print("Time to add vars/constraints:", prob_init_start_time - start_time)
    print("Time to parse:", solve_start_time - prob_init_start_time)
    print("Time to solve:", solve_end_time - solve_start_time)

    if prob.status != cp.OPTIMAL:
        print("Problem could not be solved to optimality.")
        return None

    return bottom_left.value, cell_centers.value


# Work in progress: a grid that may be rotated and scaled, parameterized by two
# opposite corners instead of a fixed cell size. As written the model multiplies
# two decision variables (see grid_x_offsets below), which cvxpy's DCP rules do
# not allow, so the problem is expected to be rejected at solve time.
def two_corner_place_grid(robot_locations, grid_size=5, subgoals=[], obstacles=[]):
    """
        Place a (possibly rotated) square grid, given by its bottom-left and top-right
        corners, so that grid points line up with robot locations. Work in progress.

        inputs:
            - robot_locations (list): locations of robots involved in conflict [[x,y], [x,y], ...]
            - grid_size (int): width of the grid in cells
            - subgoals (list): converted to an array but not used yet
            - obstacles (list): locations of circular obstacles [[x,y,r], [x,y,r], ...];
              converted to an array but not used yet
        outputs:
            - bottom_left (np.ndarray): bottom-left corner of the grid in continuous space
            - cell_centers (np.ndarray): per-robot grid points (same order as robot_locations)
            Returns None (after printing a message) if the solver does not report "optimal".
    """
    start_time = time.time()

    robot_locations = np.array(robot_locations)
    # NOTE(review): subgoals and obstacles are not referenced below yet
    subgoals = np.array(subgoals)
    obstacles = np.array(obstacles)
    N = len(robot_locations)

    # Decision variables: two opposite corners of the square grid
    bottom_left = cp.Variable(2, name='bottom_left')
    top_right = cp.Variable(2, name='top_right')

    # The other two corners of the square. The closed forms below (kept for reference)
    # put bottom_right / top_left at the diagonal's midpoint plus the half-diagonal
    # (top_right - bottom_left) / 2 rotated by -90 / +90 degrees.
    # bottom_right = 0.5 * cp.hstack([bottom_left[0] + top_right[0] - bottom_left[1] + top_right[1],
    #                                 bottom_left[0] - top_right[0] + bottom_left[1] + top_right[1]])
    # top_left = 0.5 * cp.hstack([bottom_left[0] + top_right[0] + bottom_left[1] - top_right[1],
    #                             -bottom_left[0] + top_right[0] + bottom_left[1] + top_right[1]])
    # Here they are separate variables pinned by equality constraints further below.
    bottom_right = cp.Variable(2, name='bottom_right')
    top_left = cp.Variable(2, name='top_left')

    # One cell step along the grid's own x and y axes
    grid_x_hat = cp.Variable(2, name='grid_x_hat')
    grid_y_hat = cp.Variable(2, name='grid_y_hat')

    # Decision variable: Integer grid indices for each robot
    grid_indices = cp.Variable((N, 2), integer=True, name='grid_indices')

    # Per-robot offsets from bottom_left along the grid's x and y axes; the origin is
    # reshaped to (1, 2) for broadcasting over robots.
    # NOTE(review): unlike place_grid, no half-cell offset is added, so if the offsets
    # are meant to be index * cell step these are cell corners, not centers — TODO confirm.
    grid_x_offsets = cp.Variable((N, 2), name='grid_x_offsets')
    grid_y_offsets = cp.Variable((N, 2), name='grid_y_offsets')
    cell_centers = cp.reshape(bottom_left, (1, 2), order='C') + grid_x_offsets + grid_y_offsets

    # Objective: Minimize the sum of squared distances
    cost = cp.sum_squares(robot_locations - cell_centers)

    # Constraints
    constraints = []

    # Elementwise top_right >= bottom_left keeps the diagonal pointing up-right,
    # which limits the grid rotation to [-45, 45] degrees
    constraints.append(top_right >= bottom_left)

    # Pin the bottom-right and top-left corners (2x the closed forms above)
    constraints.append(2 * bottom_right[0] == bottom_left[0] + top_right[0] - bottom_left[1] + top_right[1])
    constraints.append(2 * bottom_right[1] == bottom_left[0] - top_right[0] + bottom_left[1] + top_right[1])
    constraints.append(2 * top_left[0] == bottom_left[0] + top_right[0] + bottom_left[1] - top_right[1])
    constraints.append(2 * top_left[1] == -bottom_left[0] + top_right[0] + bottom_left[1] + top_right[1])

    # One cell step along each grid axis: side vector divided by the number of cells
    constraints.append(grid_x_hat == (bottom_right - bottom_left) * (1 / grid_size))
    constraints.append(grid_y_hat == (top_left - bottom_left) * (1 / grid_size))

    # Defining offsets in cell centers calculation
    # NOTE(review): grid_x_hat and grid_indices are both Variables, so this product is not
    # DCP and cvxpy is expected to reject the problem. Depending on the cvxpy version, `*`
    # on these non-scalar shapes may also be read as a matrix product rather than a
    # per-robot scaling — TODO confirm. A linearization (e.g. mccormick_envelope, or a
    # binary/one-hot expansion of grid_indices) is needed.
    # NOTE(review): grid_y_offsets has no matching constraint, so it is free and can
    # absorb any residual in the objective.
    constraints.append(grid_x_offsets == grid_x_hat * grid_indices)

    # Grid indices must be non-negative
    constraints.append(grid_indices >= 0)

    # Grid indices must fit within grid bounds
    constraints.append(grid_indices <= grid_size - 1)

    # No two robots can share a cell
    # Use Big M method to ensure unique grid indices
    M_ind = 10 * grid_size # Big M relative to grid indices
    M_cts = 10 * max(max(robot_locations[:,0]) - min(robot_locations[:,0]), max(robot_locations[:,1]) - min(robot_locations[:,1])) # Big M relative to robot locations
    for i in range(N):
        for j in range(i+1, N):
            # At least one of the two constraints below must be true
            xsep = cp.Variable(boolean=True)
            ysep = cp.Variable(boolean=True)
            constraints.append(xsep + ysep >= 1)

            # Enforces separation by at least 1 in the x direction
            b0 = cp.Variable(boolean=True) # b0 = 0 if robot i's x >= robot j's x, 1 otherwise
            # b0 = 0
            constraints.append(robot_locations[j, 0] - robot_locations[i, 0] <= M_cts * b0)
            constraints.append(grid_indices[i, 0] - grid_indices[j, 0] + M_ind * b0 + M_ind * (1 - xsep) >= 1)
            # b0 = 1
            constraints.append(robot_locations[i, 0] - robot_locations[j, 0] <= M_cts * (1 - b0))
            constraints.append(grid_indices[j, 0] - grid_indices[i, 0] + M_ind * (1 - b0) + M_ind * (1 - xsep) >= 1)

            # Enforces separation by at least 1 in the y direction
            b1 = cp.Variable(boolean=True) # b1 = 0 if robot i's y >= robot j's y, 1 otherwise
            # b1 = 0
            constraints.append(robot_locations[j, 1] - robot_locations[i, 1] <= M_cts * b1)
            constraints.append(grid_indices[i, 1] - grid_indices[j, 1] + M_ind * b1 + M_ind * (1 - ysep) >= 1)
            # b1 = 1
            constraints.append(robot_locations[i, 1] - robot_locations[j, 1] <= M_cts * (1 - b1))
            constraints.append(grid_indices[j, 1] - grid_indices[i, 1] + M_ind * (1 - b1) + M_ind * (1 - ysep) >= 1)

    # NOTE(review): unlike place_grid, robots and subgoals are not constrained to lie
    # inside the grid here.

    # Solve the optimization problem
    prob_init_start_time = time.time()
    prob = cp.Problem(cp.Minimize(cost), constraints)
    solve_start_time = time.time()
    prob.solve(solver=cp.SCIP)
    solve_end_time = time.time()

    print("Time to add vars/constraints:", prob_init_start_time - start_time)
    print("Time to parse:", solve_start_time - prob_init_start_time)
    print("Time to solve:", solve_end_time - solve_start_time)

    if prob.status != "optimal":
        print("Problem could not be solved to optimality.")
        return None

    print("Grid Indices:", grid_indices.value)

    # NOTE(review): only two values are returned; top_right and grid_indices are not,
    # so callers cannot draw the fitted grid — confirm the intended return signature.
    return bottom_left.value, cell_centers.value


def mccormick_envelope(w, x, xl, xu, y, yl, yu):
    """
    Generate the McCormick envelope constraints relaxing the bilinear term w = x * y.

    Given bounds xl <= x <= xu and yl <= y <= yu, each inequality below follows from
    a product of two non-negative bound gaps. The first two are under-estimators
    (w >= ...) and the last two are over-estimators (w <= ...).

    inputs:
        - w: expression standing in for the product x * y
        - x, y: the two factors (cvxpy expressions or plain numbers)
        - xl, xu: lower / upper bound on x
        - yl, yu: lower / upper bound on y
    outputs:
        - list of 4 constraints (plain booleans if every argument is a number)
    """
    return [
        # (x - xl)(y - yl) >= 0
        w >= xl * y + x * yl - xl * yl,
        # (xu - x)(yu - y) >= 0
        w >= xu * y + x * yu - xu * yu,
        # (xu - x)(y - yl) >= 0
        w <= xu * y + x * yl - xu * yl,
        # (x - xl)(yu - y) >= 0 -- an upper bound; written as >= it would cut off
        # valid points such as x = y = 1 on [0, 2] x [0, 2]
        w <= xl * y + x * yu - xl * yu,
    ]


def plot_grid(bottom_left, top_right, grid_size):
    """
    Draw the lines of a (possibly rotated) square grid with matplotlib.

    The grid is given by two opposite corners. The remaining corners are the
    diagonal's midpoint plus the half-diagonal rotated by -90 degrees
    (bottom-right) or +90 degrees (top-left).

    inputs:
        - bottom_left (array-like): [x, y] of the bottom-left corner
        - top_right (array-like): [x, y] of the top-right corner
        - grid_size (int): number of cells along each side
    """
    import matplotlib.pyplot as plt

    bl = np.array(bottom_left)
    tr = np.array(top_right)
    br = np.array([bl[0] + tr[0] - bl[1] + tr[1],
                   bl[0] - tr[0] + bl[1] + tr[1]]) / 2
    tl = np.array([bl[0] + tr[0] + bl[1] - tr[1],
                   -bl[0] + tr[0] + bl[1] + tr[1]]) / 2

    step_x = (br - bl) / grid_size  # one cell along the grid's x axis
    step_y = (tl - bl) / grid_size  # one cell along the grid's y axis

    for k in range(grid_size + 1):
        # Line parallel to the grid's y axis, k cells along x
        v_start = bl + k * step_x
        v_end = tl + k * step_x
        plt.plot([v_start[0], v_end[0]], [v_start[1], v_end[1]], 'k-')
        # Line parallel to the grid's x axis, k cells along y
        h_start = bl + k * step_y
        h_end = br + k * step_y
        plt.plot([h_start[0], h_end[0]], [h_start[1], h_end[1]], 'k-')


def get_roomba_locs(low, high, num_robots, radius=0.5, obstacles=None, max_attempts=None):
    """
    Sample random roomba locations in the box [low, high) x [low, high) by rejection.

    Candidates are drawn with the global ``np.random`` state. A candidate is rejected
    if it is within ``radius + r`` of any obstacle (x, y, r), or within ``2 * radius``
    of an already accepted location.

    inputs:
        - low, high (float): bounds of the sampling box on both axes
        - num_robots (int): number of locations to return
        - radius (float): roomba radius
        - obstacles (list, optional): circular obstacles [[x, y, r], ...]
        - max_attempts (int, optional): maximum number of candidates to draw;
          None (default) keeps sampling until enough locations are found
    outputs:
        - np.ndarray of shape (num_robots, 2) (an empty array when num_robots is 0)
    raises:
        - RuntimeError if max_attempts candidates were drawn without placing every roomba
    """
    obstacles = [] if obstacles is None else obstacles
    locs = []
    attempts = 0
    while len(locs) < num_robots:
        if max_attempts is not None and attempts >= max_attempts:
            raise RuntimeError(
                f"placed only {len(locs)} of {num_robots} roombas after {attempts} attempts"
            )
        attempts += 1
        candidate = np.random.uniform(low, high, 2)
        hits_obstacle = any(
            np.linalg.norm(candidate - np.array([obst_x, obst_y])) <= radius + obst_r
            for (obst_x, obst_y, obst_r) in obstacles
        )
        too_close = any(
            np.linalg.norm(candidate - np.array(other_loc)) <= 2 * radius
            for other_loc in locs
        )
        if not (hits_obstacle or too_close):
            locs.append(candidate)
    return np.array(locs)


def main(seed, num_robots, plot, two_corner):
    """
    Demo: sample robot locations, fit a grid over them, print it and optionally plot it.

    inputs:
        - seed (int or None): seed for NumPy's global RNG; None leaves it unseeded
        - num_robots (int): number of robots to sample
        - plot (bool): show a matplotlib figure of the result
        - two_corner (bool): use the experimental two_corner_place_grid instead of place_grid
    """
    if seed is not None:
        np.random.seed(seed)

    # Defined up front because the plotting code below uses them for both branches
    roomba_radius = 0.5
    cell_size = 2.5 * roomba_radius
    grid_size = 5

    if not two_corner:
        obstacles = np.array([[2, 2, 0.75], [4, 4, 0.5]])
        robot_locations = get_roomba_locs(low=0, high=6, num_robots=num_robots, radius=roomba_radius, obstacles=obstacles)
        subgoals = get_roomba_locs(low=0, high=6, num_robots=num_robots, radius=roomba_radius, obstacles=obstacles)

        result = place_grid(robot_locations=robot_locations,
                            cell_size=cell_size,
                            grid_size=grid_size,
                            subgoals=subgoals,
                            obstacles=obstacles)
        if result is None:
            # place_grid already printed the failure; there is nothing to report or plot
            return
        bottom_left, cell_centers = result

        print("Grid Origin (Bottom-Left Corner):", bottom_left)
        print("Cell Centers:", cell_centers)

        top_right = np.array(bottom_left) + grid_size * cell_size
    else:
        robot_locations = np.random.uniform(low=0, high=5, size=(num_robots, 2))
        print("Robot Locations:", robot_locations)

        result = two_corner_place_grid(robot_locations, grid_size)
        if result is None:
            return
        # two_corner_place_grid returns (bottom_left, cell_centers); it does not return
        # the top-right corner, so the grid lines cannot be drawn below.
        bottom_left, cell_centers = result
        top_right = None
        print("Grid Bottom-Left Corner:", bottom_left)
        print("Cell Centers:", cell_centers)

    if not plot:
        return

    import matplotlib.pyplot as plt
    import matplotlib.patches as patches

    _, ax = plt.subplots()

    if top_right is not None:
        plot_grid(bottom_left, top_right, grid_size=grid_size)

    # Plot cell centers
    cell_centers = np.array(cell_centers)
    plt.scatter(cell_centers[:, 0], cell_centers[:, 1], c='b', label='Cell Centers')
    if not two_corner:
        # Cells are axis-aligned squares of side cell_size only for place_grid
        for center in cell_centers:
            square = patches.Rectangle(center - cell_size/2, cell_size, cell_size, edgecolor='b', facecolor='b', alpha=0.2, linewidth=2)
            ax.add_patch(square)

    # Plot robot locations
    robot_locations = np.array(robot_locations)
    plt.scatter(robot_locations[:, 0], robot_locations[:, 1], c='r', label='Robot Locations')
    for (x, y) in robot_locations:
        circle = patches.Circle((x, y), radius=roomba_radius, edgecolor='r', fill=False, linewidth=2)
        ax.add_patch(circle)

    if not two_corner:
        subgoals = np.array(subgoals)
        plt.scatter(subgoals[:, 0], subgoals[:, 1], c='orange', marker='^', label='Subgoals')
        for (x, y) in subgoals:
            circle = patches.Circle((x, y), radius=roomba_radius, edgecolor='orange', fill=False, linewidth=2)
            ax.add_patch(circle)

        obstacles = np.array(obstacles)
        plt.scatter(obstacles[:, 0], obstacles[:, 1], c='black', marker='s', label='Obstacles')
        for (x, y, r) in obstacles:
            circle = patches.Circle((x, y), radius=r, edgecolor='black', fill=False, linewidth=2)
            ax.add_patch(circle)

    plt.legend(loc='upper left')

    ax.set_aspect('equal')

    plt.show()

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Fit a grid over randomly placed robots with a mixed-integer program."
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="seed for NumPy's global RNG (default: unseeded)",
    )
    parser.add_argument(
        "--num_robots",
        type=int,
        default=3,
        help="number of robots to place (default: 3)",
    )
    parser.add_argument(
        "--plot",
        action='store_true',
        help="show a matplotlib plot of the result",
    )
    parser.add_argument(
        "--two_corner",
        action='store_true',
        help="use the experimental two-corner (rotated grid) formulation",
    )
    args = parser.parse_args()

    main(args.seed, args.num_robots, args.plot, args.two_corner)