import cvxpy as cp
import numpy as np

def place_grid(robot_locations, cell_size=1, grid_shape=(5, 5), return_loss=False):
    """
        Place an axis-aligned grid so that each robot sits as close as possible to the
        center of its own grid cell.

        Solved as a mixed-integer program with cvxpy + SCIP: the continuous grid origin and
        an integer cell index per robot are chosen to minimize the sum of squared
        robot-to-cell-center distances, subject to every robot occupying a distinct cell
        inside the grid.

        inputs:
            - robot_locations (array-like): locations of robots involved in conflict [[x,y], [x,y], ...]
            - cell_size (float): the width of each grid cell in continuous space
            - grid_shape (tuple): (# of rows, # of columns) of the grid
            - return_loss (bool): if True, also return the optimal objective value
        outputs:
            - origin (np.ndarray, shape (2,)): bottom-left corner of the grid in continuous space
            - cell_centers (np.ndarray, shape (N, 2)): centers of the assigned grid cells,
              in the same order as robot_locations
            - loss: only when return_loss=True, the sum of squared distances
            Returns None (after printing a message) when there are more robots than cells
            or the solver does not reach an optimal solution.
        raises:
            - ValueError: if robot_locations is empty or not of shape (N, 2)
    """
    robot_locations = np.asarray(robot_locations, dtype=float)
    if robot_locations.size == 0:
        raise ValueError("robot_locations must contain at least one robot")
    if robot_locations.ndim != 2 or robot_locations.shape[1] != 2:
        raise ValueError(
            f"robot_locations must have shape (N, 2), got {robot_locations.shape}"
        )
    N = len(robot_locations)
    n_rows, n_cols = grid_shape

    # Pigeonhole: with more robots than cells the distinct-cell constraint can never be
    # satisfied, so report it directly instead of running a guaranteed-infeasible solve.
    if N > n_rows * n_cols:
        print(f"Cannot place {N} robots in a {n_rows}x{n_cols} grid: not enough cells.")
        return None

    # Decision variable: bottom-left corner of the grid in continuous space
    origin = cp.Variable(2, name='origin')

    # Decision variable: integer cell index of each robot.
    # Column 0 is the x (column) index, column 1 the y (row) index.
    grid_indices = cp.Variable((N, 2), integer=True, name='grid_indices')

    # Center of cell (ix, iy) = origin + (index + 1/2) * cell_size.
    # origin is reshaped to (1, 2) so it broadcasts across all N rows.
    cell_centers = cp.reshape(origin, (1, 2), order='C') + grid_indices * cell_size + cell_size / 2

    # Objective: minimize the sum of squared robot-to-center distances
    cost = cp.sum_squares(robot_locations - cell_centers)

    # Indices must be non-negative and fit inside the grid (x bounded by the number of
    # columns, y by the number of rows); this also covers square grids.
    constraints = [
        grid_indices >= 0,
        grid_indices[:, 0] <= n_cols - 1,
        grid_indices[:, 1] <= n_rows - 1,
    ]

    # No two robots may share a cell: each pair must differ by at least one index along x
    # or along y, with each alternative switched on/off by a binary via Big-M. The required
    # sign of the index difference follows the robots' real ordering along that axis.
    # Index differences lie in [-(max - 1), max - 1], so M = max(grid_shape) is the
    # smallest constant that deactivates a constraint; a tight Big-M strengthens the
    # solver's relaxation compared with an arbitrarily large one.
    M = max(n_rows, n_cols)
    for i in range(N):
        for j in range(i + 1, N):
            x_separated = cp.Variable(boolean=True)
            y_separated = cp.Variable(boolean=True)
            constraints.append(x_separated + y_separated >= 1)

            for axis, separated in ((0, x_separated), (1, y_separated)):
                direction = 1 if robot_locations[i, axis] >= robot_locations[j, axis] else -1
                constraints.append(
                    direction * (grid_indices[i, axis] - grid_indices[j, axis])
                    + M * (1 - separated) >= 1
                )

    # Solve the optimization problem
    prob = cp.Problem(cp.Minimize(cost), constraints)
    prob.solve(solver=cp.SCIP)

    if prob.status not in ["optimal", "optimal_inaccurate"]:
        print("Problem could not be solved to optimality.")
        return None

    if return_loss:
        return origin.value, cell_centers.value, prob.value
    return origin.value, cell_centers.value


# Work in progress: this formulation currently does not follow cvxpy's DCP rules.
# NOTE(review): the offending terms are the variable-times-variable products flagged
# below; presumably prob.solve() raises a DCP error before SCIP is invoked — confirm.
def place_grid_with_rotation(robot_locations, grid_size=5, return_loss=False):
    """
        Place a square grid to cover robot locations with alignment to centers. Allows for rotation and scaling of the grid.

        The grid is parameterized only by its bottom-left and top-right corners; the other
        two corners are derived from them, so both the cell size and the rotation follow
        from those two decision variables. The constraint top_right >= bottom_left keeps the
        diagonal in the first quadrant, i.e. the rotation is limited to [-45, 45] degrees.

        inputs:
            - robot_locations (list): locations of robots involved in conflict [[x,y], [x,y], ...]
            - grid_size (int): the number of cells in each row/column of the grid
            - return_loss (bool): if True, also return the optimal objective value
        outputs:
            - bottom_left (tuple): bottom-left corner of the grid in continuous space
            - top_right (tuple): top-right corner of the grid in continuous space
            - cell_centers (list): centers of grid cells for each robot (same order as robot_locations)
            - loss: when return_loss=True, sum of squared differences loss
            Returns None (after printing a message) if the solver status is not optimal.
    """
    robot_locations = np.array(robot_locations)
    N = len(robot_locations)

    # Decision variables: Bottom-left and top-right corners of the grid in continuous space
    bottom_left = cp.Variable(2, name='bottom_left')
    top_right = cp.Variable(2, name='top_right')

    # Decision variable: Integer grid indices for each robot (column 0: x index, column 1: y index)
    grid_indices = cp.Variable((N, 2), integer=True, name='grid_indices')

    # Define bottom-right and top-left corners of the grid.
    # They are the grid center (bottom_left + top_right) / 2 minus / plus the half-diagonal
    # (top_right - bottom_left) / 2 rotated 90 degrees counter-clockwise, written out per
    # coordinate so both stay affine in the decision variables.
    bottom_right = (1 / 2) * cp.hstack([
        bottom_left[0] + top_right[0] - bottom_left[1] + top_right[1],
        bottom_left[0] - top_right[0] + bottom_left[1] + top_right[1]
    ])
    top_left = (1 / 2) * cp.hstack([
        bottom_left[0] + top_right[0] + bottom_left[1] - top_right[1],
        -bottom_left[0] + top_right[0] + bottom_left[1] + top_right[1]
    ])

    # Define basis vectors for the grid
    # Vector pointing from **left -> right** on the grid, with length equal to the width of one cell (1 grid index)
    x_prime_hat = (bottom_right - bottom_left) * (1 / grid_size)
    # Vector pointing from **bottom -> top** on the grid, with length equal to the width of one cell (1 grid index)
    y_prime_hat = (top_left - bottom_left) * (1 / grid_size)

    # Calculate cell centers for each robot based on grid indices:
    # center_i = bottom_left + (ix_i + 0.5) * x_prime_hat + (iy_i + 0.5) * y_prime_hat
    cell_centers = cp.vstack([bottom_left for _ in range(N)]) # Grid origin point
    # NOTE(review): x_prime_hat and y_prime_hat depend on variables and are multiplied by
    # the integer variables grid_indices below -- these products are not DCP.
    cell_centers += cp.vstack([x_prime_hat * (grid_indices[i,0] + 0.5) for i in range(N)]) # Component of cell centers in the x_prime direction
    cell_centers += cp.vstack([y_prime_hat * (grid_indices[i,1] + 0.5) for i in range(N)]) # Component of cell centers in the y_prime direction
    # NOTE(review): leftover debug print of the symbolic expression; consider removing.
    print(cell_centers)

    # Objective: Minimize the sum of squared distances
    cost = cp.sum_squares(robot_locations - cell_centers)

    # Initialize constraints
    constraints = []

    # The top right corner of the grid can't be below or to the left of the bottom left corner
    constraints.append(top_right >= bottom_left)

    # Grid indices must be non-negative
    constraints.append(grid_indices >= 0)

    # Grid indices must fit within grid bounds
    constraints.append(grid_indices <= grid_size - 1)

    # No two robots can share a cell
    # NOTE(review): a very large Big-M can weaken the relaxation and hurt numerics; a
    # bound derived from the problem data may be preferable.
    M = 1e6 # Sufficiently large constant for Big-M
    for i in range(N):
        for j in range(i+1, N):
            x_separated = cp.Variable(boolean=True)
            y_separated = cp.Variable(boolean=True)

            # Robot i's coordinate in the x_prime direction
            robot_i_x_prime = robot_locations[i] @ x_prime_hat
            # Robot j's coordinate in the x_prime direction
            robot_j_x_prime = robot_locations[j] @ x_prime_hat

            # b0 records the ordering of the two robots along x_prime
            b0 = cp.Variable(boolean=True)
            # When b0 = 0 this forces robot i's x_prime coordinate <= robot j's
            constraints.append(robot_i_x_prime - robot_j_x_prime <= M * b0)
            # When b0 = 1 this forces robot i's x_prime coordinate >= robot j's
            constraints.append(robot_j_x_prime - robot_i_x_prime <= M * (1 - b0))
            # Enforces separation by at least 1 between the robots' x indices on the grid,
            # with the sign (2 * b0 - 1 = +1 or -1) following the ordering chosen above.
            # NOTE(review): binary b0 times integer grid_indices is another non-DCP product.
            constraints.append((2 * b0 - 1) * (grid_indices[i, 0] - grid_indices[j, 0]) + M * (1 - x_separated) >= 1)

            # Robot i's coordinate in the y_prime direction
            robot_i_y_prime = robot_locations[i] @ y_prime_hat
            # Robot j's coordinate in the y_prime direction
            robot_j_y_prime = robot_locations[j] @ y_prime_hat

            # b1 records the ordering of the two robots along y_prime
            b1 = cp.Variable(boolean=True)
            # When b1 = 0 this forces robot i's y_prime coordinate <= robot j's
            constraints.append(robot_i_y_prime - robot_j_y_prime <= M * b1)
            # When b1 = 1 this forces robot i's y_prime coordinate >= robot j's
            constraints.append(robot_j_y_prime - robot_i_y_prime <= M * (1 - b1))
            # Enforces separation by at least 1 between the robots' y indices on the grid
            # (same sign trick as for x; also a non-DCP product)
            constraints.append((2 * b1 - 1) * (grid_indices[i, 1] - grid_indices[j, 1]) + M * (1 - y_separated) >= 1)

            # Robots must be separated in at least one of the x, y directions
            constraints.append(x_separated + y_separated >= 1)

    # Solve the optimization problem
    prob = cp.Problem(cp.Minimize(cost), constraints)
    prob.solve(solver=cp.SCIP)

    if prob.status not in ["optimal", "optimal_inaccurate"]:
        print("Problem could not be solved to optimality.")
        return None

    if return_loss:
        return bottom_left.value, top_right.value, cell_centers.value, prob.value
    return bottom_left.value, top_right.value, cell_centers.value


def plot_grid(bottom_left, top_right, grid_size):
    """
        Draw grid lines on the current matplotlib figure.

        Two call forms are supported, distinguished by whether the second argument is a scalar:
            - plot_grid(bottom_left, top_right, grid_size): a square grid, possibly rotated,
              given its bottom-left and top-right corners (the values returned by
              place_grid_with_rotation) and the number of cells per side.
            - plot_grid(origin, cell_size, grid_shape): an axis-aligned grid given its
              bottom-left corner, the scalar cell width and grid_shape = (# rows, # columns),
              matching place_grid's parameters.

        inputs:
            - bottom_left (array-like): bottom-left corner [x, y] of the grid
            - top_right (array-like or float): top-right corner [x, y], or the cell width
            - grid_size (int or tuple): cells per side, or (# of rows, # of columns)
        outputs:
            - None; the lines are drawn with plt.plot
        raises:
            - ValueError: if corners are given together with a non-square grid_size
    """
    import matplotlib.pyplot as plt

    bottom_left = np.asarray(bottom_left, dtype=float)

    if np.ndim(grid_size) == 0:
        n_rows = n_cols = int(grid_size)
    else:
        n_rows, n_cols = (int(n) for n in grid_size)

    if np.ndim(top_right) == 0:
        # Axis-aligned form: the second argument is the cell width.
        cell_size = float(top_right)
        x_step = np.array([cell_size, 0.0])
        y_step = np.array([0.0, cell_size])
    else:
        if n_rows != n_cols:
            raise ValueError(
                "bottom_left/top_right only define a square grid; "
                f"got grid_size={grid_size!r}"
            )
        top_right = np.asarray(top_right, dtype=float)
        # Remaining corners: center -/+ the half-diagonal rotated 90 degrees counter-clockwise.
        center = (bottom_left + top_right) / 2
        half_diagonal = (top_right - bottom_left) / 2
        rotated = np.array([-half_diagonal[1], half_diagonal[0]])
        bottom_right = center - rotated
        top_left = center + rotated
        # One-cell steps along the grid's own x and y axes.
        x_step = (bottom_right - bottom_left) / n_cols
        y_step = (top_left - bottom_left) / n_rows

    # Lines of constant x index ("vertical" in grid coordinates)
    for i in range(n_cols + 1):
        start = bottom_left + i * x_step
        end = start + n_rows * y_step
        plt.plot([start[0], end[0]], [start[1], end[1]], 'k-')
    # Lines of constant y index ("horizontal" in grid coordinates)
    for j in range(n_rows + 1):
        start = bottom_left + j * y_step
        end = start + n_cols * x_step
        plt.plot([start[0], end[0]], [start[1], end[1]], 'k-')


def main(allow_rotations, num_robots, num_angles, seed):
    """
        Demo: scatter random robots in [0, 5) x [0, 5), fit a grid to them and plot the result.

        inputs:
            - allow_rotations (bool): use place_grid_with_rotation (square grid with free
              rotation and scale) instead of the axis-aligned place_grid
            - num_robots (int): number of random robot locations to generate
            - num_angles (int): currently unused; kept for command-line compatibility
            - seed (int or None): seed for numpy's global RNG, for reproducible layouts
    """
    import matplotlib.pyplot as plt

    if seed is not None:
        np.random.seed(seed)

    robot_locations = np.random.uniform(low=0, high=5, size=(num_robots, 2))
    cell_size = 1
    grid_shape = (5, 5)

    if allow_rotations:
        # The rotated formulation only supports square grids (grid_size cells per side)
        # and chooses the cell size itself, so cell_size is not passed.
        # NOTE: place_grid_with_rotation is marked work-in-progress (not DCP yet).
        result = place_grid_with_rotation(robot_locations, grid_size=grid_shape[0])
    else:
        result = place_grid(robot_locations, cell_size, grid_shape)

    # Both solvers print a message and return None when no placement is found.
    if result is None:
        return

    plt.figure(figsize=(6, 6))

    if allow_rotations:
        bottom_left, top_right, cell_centers = result
        print("Grid Bottom-Left Corner:", bottom_left)
        print("Grid Top-Right Corner:", top_right)
        plot_grid(bottom_left, top_right, grid_shape[0])
    else:
        origin, cell_centers = result
        print("Grid Origin (Bottom-Left Corner):", origin)
        plot_grid(origin, cell_size, grid_shape)

    # Plot robot locations
    plt.scatter(robot_locations[:, 0], robot_locations[:, 1], c='r', label='Robot Locations')

    # Plot cell centers
    cell_centers = np.asarray(cell_centers)
    plt.scatter(cell_centers[:, 0], cell_centers[:, 1], c='b', label='Cell Centers')

    plt.legend(loc='upper left')

    plt.show()

def parse_bool(value):
    """
        Parse a command-line boolean value.

        argparse's ``type=bool`` is a trap: bool("False") is True because every non-empty
        string is truthy. This accepts the usual spellings instead (case-insensitive).

        inputs:
            - value (str or bool): e.g. "true"/"false", "yes"/"no", "1"/"0"
        outputs:
            - (bool) the parsed value
        raises:
            - ValueError: for any other string (argparse reports it as an invalid value)
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0"):
        return False
    raise ValueError(f"invalid boolean value: {value!r}")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    # nargs="?" + const=True accepts both "--allow_rotations" and "--allow_rotations False".
    parser.add_argument(
        "--allow_rotations",
        type=parse_bool,
        nargs="?",
        const=True,
        default=False,
        help="allow the grid to rotate and scale (true/false)",
    )
    parser.add_argument(
        "--num_robots",
        type=int,
        default=2,
        help="number of random robot locations",
    )
    parser.add_argument(
        "--num_angles",
        type=int,
        default=18,
        help="currently unused",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="random seed for reproducible layouts",
    )
    args = parser.parse_args()

    main(args.allow_rotations, args.num_robots, args.num_angles, args.seed)