import cv2
import numpy as np
import math

# Module-wide debug switch. When truthy, the detection functions below open
# OpenCV windows showing intermediate images. It lives at module level so it
# does not have to be threaded through every function; change it through
# init_show_cv().
show_cv = None


def init_show_cv(val):
    """Set the module-level ``show_cv`` debug-display flag.

    Args:
        val: new flag value; any truthy value enables the debug windows.
    """
    global show_cv
    show_cv = val

def find_longest_lines(img):
    """Find candidate vertical and horizontal grid lines in a BGR image.

    Pipeline: grayscale -> Canny edges (thresholds 50/100, aperture 3) ->
    standard Hough transform (1 px / 1 degree resolution, 60-vote threshold)
    restricted to two angular windows of +/- pi/60 rad (3 degrees): one around
    theta = 0 (vertical lines) and one around theta = pi/2 (horizontal lines).
    The (rho, theta) results are converted to endpoint segments. Lines whose
    midpoint lies within 50 px of an already-kept line are dropped. At most 9
    lines per orientation are returned.

    Returns:
        (vertical, horizontal): two lists of lines shaped [[x1, y1, x2, y2]].
    """
    edges = cv2.Canny(cv2.cvtColor(img, cv2.COLOR_BGR2GRAY), 50, 100, apertureSize=3)

    if show_cv:
        cv2.imshow('Canny Filter', edges)
        cv2.waitKey(0)
        cv2.destroyAllWindows()

    theta_thresh = 60  # each angular window is +/- pi / theta_thresh wide
    hough_common = dict(rho=1, theta=np.pi/180, threshold=60)
    raw_vertical = cv2.HoughLines(
        edges,
        min_theta=-np.pi/theta_thresh,
        max_theta=np.pi/theta_thresh,
        **hough_common,
    )
    raw_horizontal = cv2.HoughLines(
        edges,
        min_theta=(theta_thresh/2-1)*np.pi/theta_thresh,
        max_theta=(theta_thresh/2+1)*np.pi/theta_thresh,
        **hough_common,
    )

    # Drop near-duplicate detections of the same physical line.
    vertical = filter_lines(convert_to_cartesian(raw_vertical), 50)
    horizontal = filter_lines(convert_to_cartesian(raw_horizontal), 50)

    # Keep at most 9 lines per orientation.
    # NOTE(review): an older comment called these the "9 largest lines", but
    # Hough lines have no length. The ranking key is the smaller endpoint
    # coordinate along each line's own direction (y for vertical, x for
    # horizontal). Because every segment is extended a fixed distance from its
    # foot point, this key likely reflects angle/foot position rather than
    # size. Confirm this is the intended selection.
    def lowest_y(segment):
        return min(segment[0][1], segment[0][3])

    def lowest_x(segment):
        return min(segment[0][0], segment[0][2])

    return sorted(vertical, key=lowest_y)[:9], sorted(horizontal, key=lowest_x)[:9]

def convert_to_cartesian(lines, length=1000):
    """Convert Hough (rho, theta) lines into two-endpoint segments.

    Each line is anchored at the foot of the perpendicular from the origin,
    (rho*cos(theta), rho*sin(theta)). That point is moved ``length`` pixels in
    both directions along the line's direction vector (-sin(theta), cos(theta)).
    The two resulting points are truncated to ints.

    Args:
        lines: result of ``cv2.HoughLines`` (items shaped [[rho, theta]]), or
            None when no lines were detected.
        length: distance in pixels from the foot point to each endpoint. The
            default of 1000 matches the previously hard-coded value. On images
            larger than about 1000 px the segments may stop short of the
            border; pass a larger value (e.g. the image diagonal) there.

    Returns:
        List of segments shaped [[x1, y1, x2, y2]] with int coordinates;
        an empty list when ``lines`` is None.
    """
    if lines is None:
        return []

    segments = []
    for line in lines:
        rho, theta = line[0]

        cos_theta = np.cos(theta)
        sin_theta = np.sin(theta)

        # Foot of the perpendicular from the origin onto the line.
        x0 = cos_theta * rho
        y0 = sin_theta * rho

        # Step +/- length along the line direction (-sin, cos).
        x1 = int(x0 + length * (-sin_theta))
        y1 = int(y0 + length * (cos_theta))
        x2 = int(x0 - length * (-sin_theta))
        y2 = int(y0 - length * (cos_theta))

        segments.append([[x1, y1, x2, y2]])
    return segments

def filter_lines(lines, min_distance):
    """Thin out near-duplicate lines, keeping the earliest of each cluster.

    Lines are visited in input order. A line is kept only if its midpoint is
    at least ``min_distance`` (Euclidean) from the midpoint of every line
    kept so far. Comparing midpoints alone is a deliberate simplification
    for speed; it is adequate when lines are roughly parallel and of similar
    length.

    Args:
        lines: iterable of lines shaped [[x1, y1, x2, y2]], or None.
        min_distance: minimum allowed midpoint separation in pixels.

    Returns:
        List of the kept line objects (the originals, not copies); empty when
        ``lines`` is None.
    """
    if lines is None:
        return []

    def midpoint(line):
        xa, ya, xb, yb = line[0]
        return (xa + xb) / 2, (ya + yb) / 2

    kept = []
    kept_midpoints = []  # cached midpoints of `kept`, in the same order
    for candidate in lines:
        cx, cy = midpoint(candidate)
        crowded = any(
            np.sqrt((cx - kx)**2 + (cy - ky)**2) < min_distance
            for kx, ky in kept_midpoints
        )
        if not crowded:
            kept.append(candidate)
            kept_midpoints.append((cx, cy))
    return kept

# Minimum contour area (px^2) for a contour to count as the board outline.
# Assumes the board fills a large part of the frame; a high floor also keeps
# the contour scans cheap on slow hardware.
_MIN_BOARD_AREA = 100000


def _show_and_wait(title, image):
    """Show *image* in a window named *title*, wait for a key, then close all windows."""
    cv2.imshow(title, image)
    cv2.waitKey(0)
    cv2.destroyAllWindows()


def _contour_centroids(contours):
    """Return the integer (x, y) centroid of each contour with non-zero area.

    Contours whose zeroth moment is 0 are skipped because their centroid is
    undefined.
    """
    centroids = []
    for contour in contours:
        moments = cv2.moments(contour)
        if moments["m00"] != 0:
            centroids.append((int(moments["m10"] / moments["m00"]),
                              int(moments["m01"] / moments["m00"])))
    return centroids


def _find_board_outline(contours, height, width):
    """Approximate the outer border of the drawn grid as a 4-sided polygon.

    Args:
        contours: contours of the rasterized grid lines (RETR_TREE).
        height, width: image size used for the scratch mask.

    Returns:
        (thick_contours, largest, board_quad). ``thick_contours`` are the
        external contours of the thickened grid lines. ``board_quad`` is the
        approxPolyDP result (4 points) of the largest such contour above
        _MIN_BOARD_AREA, and ``largest`` is its index. When none is found,
        ``board_quad`` is None and ``largest`` is -1.
    """
    # Drop the single largest contour (if any exceeds the threshold). It
    # carries the ragged outer edges of the drawn lines. Nothing is removed
    # when no contour qualifies; the old pop(-1) discarded an unrelated one.
    largest = -1
    max_area = _MIN_BOARD_AREA
    for i, contour in enumerate(contours):
        area = cv2.contourArea(contour)
        if area > max_area:
            max_area = area
            largest = i
    remaining = [c for i, c in enumerate(contours) if i != largest]

    # Thicken the remaining lines so neighbouring segments merge.
    contour_mask = np.zeros((height, width), dtype=np.uint8)
    cv2.drawContours(contour_mask, remaining, -1, 255, thickness=10)
    thick_contours, _ = cv2.findContours(contour_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Largest thickened contour that simplifies (Douglas-Peucker) to 4 corners.
    max_area = _MIN_BOARD_AREA
    largest = -1
    board_quad = None
    for i, contour in enumerate(thick_contours):
        area = cv2.contourArea(contour)
        if area > max_area:
            epsilon = 0.05 * cv2.arcLength(contour, True)
            approx = cv2.approxPolyDP(contour, epsilon, True)
            if len(approx) == 4:
                max_area = area
                largest = i
                board_quad = approx
    return thick_contours, largest, board_quad


def _average_square_hue(hue_channel, corners, mask):
    """Return the mean hue of the coloured pixels inside one grid square.

    Args:
        hue_channel: 2-D uint8 hue plane in which masked-out pixels are 0.
        corners: (tl, tr, bl, br) integer (x, y) corners of the square.
        mask: scratch uint8 array with the same shape as ``hue_channel``;
            it is overwritten.

    Returns:
        Mean hue (float) of pixels inside the quadrilateral whose hue is
        non-zero, or 0.0 when there are none.
    """
    tl, tr, bl, br = corners
    mask.fill(0)
    cv2.fillPoly(mask, np.array([[tl, tr, br, bl]], dtype=np.int32), 255)

    # Scan window: from the left/top corners up to (excluding) the
    # right/bottom corners.
    x_lo, x_hi = min(tl[0], bl[0]), max(tr[0], br[0])
    y_lo, y_hi = min(tl[1], tr[1]), max(bl[1], br[1])
    if x_hi <= x_lo or y_hi <= y_lo:
        return 0.0
    roi_hue = hue_channel[y_lo:y_hi, x_lo:x_hi]
    roi_mask = mask[y_lo:y_hi, x_lo:x_hi]

    # Hue 0 marks pixels removed by the saturation/brightness mask.
    # NOTE(review): this also ignores genuinely red pixels whose hue is 0.
    colored = roi_hue[(roi_mask == 255) & (roi_hue != 0)]
    if colored.size == 0:
        return 0.0
    # Sum in int64. Accumulating uint8 scalars overflows, because under NumPy 2
    # (NEP 50) uint8 + Python int stays uint8.
    return float(colored.sum(dtype=np.int64)) / colored.size


def _classify_hue(avg_hue, hue_ranges):
    """Return the first colour whose inclusive hue range holds *avg_hue*.

    Returns None when *avg_hue* is 0 (no coloured pixels) or matches no range.
    """
    if avg_hue == 0:
        return None
    for color, (lower, upper) in hue_ranges.items():
        if lower <= avg_hue <= upper:
            return color
    return None


def find_board_and_pieces(img):
    """Locate a 9x9-line (8x8-square) board in a BGR image and classify pieces.

    Steps:
      1. Detect 9 vertical and 9 horizontal lines (find_longest_lines).
      2. Rasterize them. Their AND gives the grid intersections; their OR
         gives the grid lines.
      3. Approximate the board's outer border as a quadrilateral and
         perspective-warp the image so that border fills the frame.
      4. Re-locate the 81 intersections in the warped frame. For each of the
         64 squares, average the hue of sufficiently saturated and bright
         pixels (S >= 100 and V >= 100), then map it to a colour name.

    The detected grid is printed as a table. When the module-level
    ``show_cv`` flag is truthy, intermediate images are shown.

    Returns:
        An 8x8 list of rows of (color, avg_hue) tuples, with (None, 0) for
        squares where no colour matched. Returns None (after printing an
        error) when the grid, the board outline or the 81 warped
        intersections cannot be found.
    """
    vertical_lines, horizontal_lines = find_longest_lines(img)
    print("# of Vertical:", len(vertical_lines))
    print("# of Horizontal:", len(horizontal_lines))

    if len(vertical_lines) != 9 or len(horizontal_lines) != 9:
        print("Error: Grid does not match expected 9x9")
        return None

    # Rasterize each line family into its own mask so the grid lines and
    # their intersections can be extracted separately.
    height, width, _ = img.shape
    vertical_mask = np.zeros((height, width), dtype=np.uint8)
    horizontal_mask = np.zeros((height, width), dtype=np.uint8)
    for line in vertical_lines:
        x1, y1, x2, y2 = line[0]
        cv2.line(vertical_mask, (x1, y1), (x2, y2), 255, 2)
    for line in horizontal_lines:
        x1, y1, x2, y2 = line[0]
        cv2.line(horizontal_mask, (x1, y1), (x2, y2), 255, 2)

    intersection = cv2.bitwise_and(vertical_mask, horizontal_mask)
    board_lines = cv2.bitwise_or(vertical_mask, horizontal_mask)

    contours, _ = cv2.findContours(board_lines, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)

    if show_cv:
        # Pre-warp intersection centroids; only this debug view uses them.
        raw_intersections, _ = cv2.findContours(intersection, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        raw_grid = sort_square_grid_coords(_contour_centroids(raw_intersections), unpacked=False)
        board_lines_img = img.copy()
        cv2.drawContours(board_lines_img, contours, -1, (255, 255, 0), 2)
        i = 0
        for row in raw_grid:
            for point in row:
                cv2.circle(board_lines_img, point, 5, (255 - (3 * i), (i % 9) * 28, 3 * i), -1)
                i += 1
        _show_and_wait('Lines of Board', board_lines_img)

    thick_contours, largest, max_rect = _find_board_outline(contours, height, width)
    if max_rect is None:
        print("Error: Could not find a 4-sided board outline")
        return None

    # Perspective transform that maps the board outline onto the full frame.
    corners = [tuple(corner) for corner in max_rect.reshape(-1, 2)]
    print(corners)
    tl, tr, bl, br = sort_square_grid_coords(corners, unpacked=True)
    src = np.float32([tl, tr, bl, br])
    dest = np.float32([[0, 0], [width, 0], [0, height], [width, height]])
    transform = cv2.getPerspectiveTransform(src, dest)
    warped_img = cv2.warpPerspective(np.uint8(img), transform, (width, height))
    warped_ip = cv2.warpPerspective(np.uint8(intersection), transform, (width, height))

    # Grid intersections in the warped frame (centroids of the warped blobs).
    warped_contours, _ = cv2.findContours(warped_ip, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    intersection_points = _contour_centroids(warped_contours)

    if show_cv:
        contours_img = img.copy()
        cv2.drawContours(contours_img, thick_contours, -1, (0, 255, 0), 2)
        cv2.drawContours(contours_img, [thick_contours[largest]], -1, (0, 0, 255), 2)
        cv2.drawContours(contours_img, [max_rect], -1, (255, 0, 0), 2)
        for x, y in corners:
            cv2.circle(contours_img, (int(x), int(y)), 5, (0, 255, 255), -1)
        _show_and_wait('Contours', contours_img)

    # The per-square loop below indexes a full 9x9 grid of points.
    if len(intersection_points) != 81:
        print("Error: Expected 81 grid intersections after warping, found", len(intersection_points))
        return None
    grid = sort_square_grid_coords(intersection_points, unpacked=False)

    # COLOR / PIECE DETECTION ------------------------------------------------
    hsv_img = cv2.cvtColor(warped_img, cv2.COLOR_BGR2HSV)
    hsv_mask_sat = cv2.inRange(hsv_img[:, :, 1], 100, 255)  # saturation mask
    hsv_mask_bright = cv2.inRange(hsv_img[:, :, 2], 100, 255)  # brightness mask
    hsv_mask = cv2.bitwise_and(hsv_mask_sat, hsv_mask_bright)
    # Zero every HSV channel of pixels that are too dull or too dark.
    hsv_after = cv2.bitwise_and(hsv_img, hsv_img, mask=hsv_mask)

    # Inclusive OpenCV hue ranges (0-179). 'red' extends past 179 in the
    # hope of covering the wrap-around, but averaged hues never exceed 179.
    hue_thresh_dict = {'red': (170, 190), 'orange': (11, 30), 'yellow': (31, 40), 'green': (50, 70),
                       'blue': (110, 130), 'teal': (90, 109), 'pink': (140, 160)}

    hue_channel = hsv_after[:, :, 0]
    square_mask = np.zeros((height, width), dtype=np.uint8)  # reused scratch buffer
    color_grid = []
    count = 0
    for i in range(8):
        row = []
        for j in range(8):
            square = (grid[i][j], grid[i][j + 1], grid[i + 1][j], grid[i + 1][j + 1])
            avg_hue = _average_square_hue(hue_channel, square, square_mask)
            print(count, avg_hue)
            color = _classify_hue(avg_hue, hue_thresh_dict)
            row.append((None, 0) if color is None else (color, avg_hue))
            count += 1
        color_grid.append(row)

    for row in color_grid:
        for color, hue in row:
            if color is None:
                print("\t\t|", end="")
            else:
                print(color, int(hue), "\t|", end="")
        print("")

    if show_cv:
        cv2.imshow('Warped', warped_img)

        hsv_after_intersections = hsv_after.copy()
        for row in grid:
            for point in row:
                cv2.circle(hsv_after_intersections, point, 1, (255, 255, 255), -1)
        cv2.imshow('hsv_after_intersections', hsv_after_intersections)
        cv2.waitKey(0)
        cv2.destroyAllWindows()

    return color_grid

def sort_square_grid_coords(coordinates, unpacked):
    """Order the points of a square grid row by row, top-down then left-right.

    The points are sorted by y, split into consecutive rows of
    ``int(sqrt(len(coordinates)))`` points, and each row is sorted by x.
    For example, 81 points become 9 rows of 9. This assumes the point count is
    a perfect square and the grid is roughly axis-aligned, so that every point
    of one row has a smaller y than every point of the next. Otherwise rows
    get mixed, or the last row is short.

    Args:
        coordinates: sequence of (x, y) points.
        unpacked: falsy -> return a list of rows (lists of points);
            truthy -> return one flat list in the same order.

    Returns:
        The ordered points, nested or flat depending on ``unpacked``. Empty
        input yields ``[]`` in either mode.
    """
    row_len = int(math.sqrt(len(coordinates)))
    if row_len == 0:
        # range() with a zero step would raise ValueError.
        return []

    by_y = sorted(coordinates, key=lambda coord: coord[1])
    rows = [sorted(by_y[start:start + row_len], key=lambda coord: coord[0])
            for start in range(0, len(by_y), row_len)]

    if not unpacked:
        return rows
    return [coord for row in rows for coord in row]