Skip to content
Snippets Groups Projects
board_detector.py 21.3 KiB
Newer Older
  • Learn to ignore specific revisions
  • zalonzo2's avatar
    zalonzo2 committed
    # import os
    from PIL import Image
    
    zalonzo2's avatar
    zalonzo2 committed
    import psutil
    import time
    
    
    # global show_cv because I didn't want to have show_cv as an input to every function
    
    # also need skip_camera so I can use cv2.imshow if I'm using my PC
    
    show_cv = None
    skip_camera = None
    img_size = None


    def init_global(bool1, bool2, bool3):
        """Store the module-wide settings read by the other functions in this file.

        bool1 -> show_cv: when truthy, intermediate images are displayed.
        bool2 -> skip_camera: when truthy, display_img uses cv2.imshow instead of PIL.
        bool3 -> img_size: despite the parameter name this is used as a number
                 (find_lines computes int(img_size / 10) from it).
        """
        # All three names must be declared global; without the declaration the
        # assignment only creates a local that is discarded on return.
        global show_cv
        global skip_camera
        global img_size
        show_cv = bool1
        skip_camera = bool2
        img_size = bool3
    
    zalonzo2's avatar
    zalonzo2 committed
    # had to write a custom img display function because conflict with picamera2 and opencv made it so I can't use imshow
    # (I had to install headless opencv to remove conflict which removes imshow)
    def display_img(img_array):
        """Show every image in img_array and block until the user dismisses them.

        img_array: a list of OpenCV (BGR) images. Anything longer than
        max_windows is refused, which also catches the mistake of passing a
        single image array instead of a list of images.

        With skip_camera set the images go through cv2.imshow and a key press
        closes them. Otherwise each image is opened through PIL, Enter is awaited
        on stdin, and every process named "display" is killed to close the viewers.
        """
        max_windows = 5
        # Reject only when there are MORE than max_windows images, so the check
        # agrees with the ">{max_windows}" wording of the message.
        if len(img_array) > max_windows:
            print(f"Too many images (>{max_windows})")
            return

        if skip_camera:  # if not running on raspi
            for i, cv2_img in enumerate(img_array, start=1):
                cv2.imshow('Image ' + str(i), cv2_img)
            cv2.waitKey(0)
            cv2.destroyAllWindows()
            return

        # if running on raspi
        for cv2_img in img_array:
            cv2_to_pil(cv2_img).show()
        input()
        # Pre-fetching the name through process_iter avoids an exception when a
        # process exits (or is inaccessible) between listing and inspection.
        for proc in psutil.process_iter(['name']):
            if proc.info['name'] == "display":
                try:
                    proc.kill()
                except psutil.NoSuchProcess:
                    # Already gone, which is the state we wanted.
                    pass
    
    def cv2_to_pil(cv2_img):
        """Return a PIL image built from an OpenCV BGR image."""
        # Swap the channel order to RGB before handing the array to PIL.
        return Image.fromarray(cv2.cvtColor(cv2_img, cv2.COLOR_BGR2RGB))
    
    def pil_to_cv2(pil_image):
        """Return an OpenCV BGR uint8 array built from a PIL image."""
        rgb_pixels = np.array(pil_image, dtype=np.uint8)
        # Swap the channel order from RGB to the BGR layout OpenCV uses.
        return cv2.cvtColor(rgb_pixels, cv2.COLOR_RGB2BGR)
    
    def find_lines(img):
        """Detect the near-vertical and near-horizontal lines of the board grid.

        img: BGR image.

        Returns (vertical, horizontal): two lists of at most 9 entries shaped
        [[x1, y1, x2, y2]], after Hough detection, conversion to end points and
        removal of lines that lie too close to one another.

        NOTE(review): `cutoff` (border width in pixels) and `theta_thresh`
        (angular window divisor) are not defined anywhere in the visible part of
        this file; they must exist at module level for this function to run.
        """
        gray_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        edges = cv2.Canny(gray_img, 50, 100, apertureSize=3)

        h, w = edges.shape

        # don't detect any lines at the edges (make the edges black)
        # Same condition as a per-pixel loop, evaluated once as a boolean mask.
        rows, cols = np.ogrid[:h, :w]
        border = (cols < cutoff) | (cols > w - cutoff) | (rows < cutoff) | (rows > h - cutoff)
        edges[border] = 0

        if show_cv:
            display_img([edges])

        # Horizontal lines have a normal angle within pi/theta_thresh of pi/2,
        # vertical lines within pi/theta_thresh of 0.
        horizontal_lines = cv2.HoughLines(
            edges, rho=1, theta=np.pi/180, threshold=60,
            min_theta=(theta_thresh/2-1)*np.pi/theta_thresh,
            max_theta=(theta_thresh/2+1)*np.pi/theta_thresh,
        )
        vertical_lines = cv2.HoughLines(
            edges, rho=1, theta=np.pi/180, threshold=60,
            min_theta=-np.pi/theta_thresh,
            max_theta=np.pi/theta_thresh,
        )

        vertical_line_points = convert_to_cartesian(vertical_lines)
        horizontal_line_points = convert_to_cartesian(horizontal_lines)

        # filter lines too close to each other
        min_gap = int(img_size/10)
        filtered_vertical = filter_lines(vertical_line_points, min_gap)
        filtered_horizontal = filter_lines(horizontal_line_points, min_gap)

        # keep the first 9 lines after ordering by the smaller end-point coordinate
        # (index 1/3 for the vertical set, index 0/2 for the horizontal set)
        sorted_vertical = sorted(filtered_vertical, key=lambda line: min(line[0][1], line[0][3]))[:9]
        sorted_horizontal = sorted(filtered_horizontal, key=lambda line: min(line[0][0], line[0][2]))[:9]

        return sorted_vertical, sorted_horizontal
    
    
    def convert_to_cartesian(lines):
        """Convert (rho, theta) Hough lines into long segments.

        lines: iterable whose items hold a (rho, theta) pair at index 0 (the
        layout find_lines receives from cv2.HoughLines), or None.

        Returns a list of [[x1, y1, x2, y2]] entries with integer end points
        located 1000 units either side of the point on the line closest to the
        origin. Returns an empty list when lines is None.
        """
        # Initialise up front so the None case returns [] instead of failing
        # on an unbound name.
        line_points = []
        if lines is None:
            return line_points

        for line in lines:
            rho, theta = line[0]

            cos_theta = np.cos(theta)
            sin_theta = np.sin(theta)

            # foot of the perpendicular from the origin onto the line
            x0 = cos_theta * rho
            y0 = sin_theta * rho

            # step along the line direction (-sin, cos) in both directions
            x1 = int(x0 + 1000 * (-sin_theta))
            y1 = int(y0 + 1000 * (cos_theta))
            x2 = int(x0 - 1000 * (-sin_theta))
            y2 = int(y0 - 1000 * (cos_theta))

            line_points.append([[x1, y1, x2, y2]])
        return line_points
    
    
    def filter_lines(lines, min_distance):
        """Drop lines whose midpoint is closer than min_distance to a kept line.

        lines: iterable of [[x1, y1, x2, y2]] entries, or None.
        min_distance: minimum allowed distance between segment midpoints.

        Returns the kept entries in input order (the first line of any close
        group wins). Returns an empty list when lines is None.

        (this assumes lines are around the same size and parallel)
        (extremely simplified to improve computational speed because this is all we need)
        """
        filtered_lines = []
        if lines is None:
            return filtered_lines

        # Midpoints of the lines kept so far, cached so they are not recomputed
        # for every comparison.
        kept_midpoints = []
        for line in lines:
            x1, y1, x2, y2 = line[0]
            mid_x = (x1 + x2) / 2
            mid_y = (y1 + y2) / 2

            too_close = any(
                np.sqrt((mid_x - kept_x)**2 + (mid_y - kept_y)**2) < min_distance
                for kept_x, kept_y in kept_midpoints
            )
            if not too_close:
                filtered_lines.append(line)
                kept_midpoints.append((mid_x, mid_y))

        return filtered_lines
    
        # NOTE(review): the `def` line that owns this body is not visible in this file.
        # The body reads `img` (a 3-dimensional array, per the `img.shape` unpack below) and
        # returns (warped_img, sorted_warped_points), or (None, None) when the grid is not 9x9.
        vertical_lines, horizontal_lines = find_lines(img)


        # create bitmasks for vert and horiz so we can get lines and intersections
        height, width, _ = img.shape
        vertical_mask = np.zeros((height, width), dtype=np.uint8)
        horizontal_mask = np.zeros((height, width), dtype=np.uint8)

        # draw every detected line, 2 px thick, into its own single-channel mask
        for line in vertical_lines:
            x1, y1, x2, y2 = line[0]
            cv2.line(vertical_mask, (x1, y1), (x2, y2), (255), 2)

        for line in horizontal_lines:
            x1, y1, x2, y2 = line[0]
            cv2.line(horizontal_mask, (x1, y1), (x2, y2), (255), 2)


        # get lines and intersections of grid and corresponding contours

        # AND keeps only the pixels where a vertical and a horizontal line cross; OR keeps the whole grid
        intersection = cv2.bitwise_and(vertical_mask, horizontal_mask)
        board_lines = cv2.bitwise_or(vertical_mask, horizontal_mask)

        contours, hierarchy = cv2.findContours(board_lines, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)

        # debug view only: draw the grid contours and the intersection midpoints on a copy of img
        if (show_cv):
            intersections, hierarchy = cv2.findContours(intersection, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)

            # find midpoints of intersection contours (this gives us exact coordinates of the corners of the grid)
            intersection_points = []
            for contour in intersections:
                M = cv2.moments(contour)
                if (M["m00"] != 0):
                    midpoint_x = int(M["m10"] / M["m00"])
                    midpoint_y = int(M["m01"] / M["m00"])
                    intersection_points.append((midpoint_x, midpoint_y))

            # sort the coordinates from left to right then top to bottom
            sorted_intersection_points = sort_square_grid_coords(intersection_points, unpacked=False)


            board_lines_img = img.copy()
            cv2.drawContours(board_lines_img, contours, -1, (255, 255, 0), 2)

            # the circle colour varies with the running index i so the sort order is visible
            i = 0
            for points in sorted_intersection_points:
                for point in points:
                    cv2.circle(board_lines_img, point, 5, (255 - (3 * i), (i % 9) * 28, 3 * i), -1)
                    i += 1

    zalonzo2's avatar
    zalonzo2 committed
            display_img([board_lines_img])

        # bail out unless exactly 9 vertical and 9 horizontal lines were found
        if (len(vertical_lines) != 9 or len(horizontal_lines) != 9):
            print("Error: Grid does not match expected 9x9")
            print("# of Vertical:",len(vertical_lines))
            print("# of Horizontal:",len(horizontal_lines))
            return None, None


        # find largest contour and get rid of it because it contains weird edges from lines
        max_area = 100000 # we're assuming board is going to be big (hopefully to speed up computation on raspberry pi)
        largest = -1
        for i, contour in enumerate(contours):
            area = cv2.contourArea(contour)
            if area > max_area:
                max_area = area
                largest = i
        # "largest" is index of largest contour

        # get rid of contour containing the edges of the lines
        # NOTE(review): when no contour exceeds the 100000 area floor, `largest` is still -1 and
        # pop(-1) removes the last contour instead (or raises IndexError on an empty list).
        contours = list(contours)
        contours.pop(largest)
        contours = tuple(contours)

        # thicken lines so that connections are made
        contour_mask = np.zeros((height, width), dtype=np.uint8)
        cv2.drawContours(contour_mask, contours, -1, (255), thickness=10)
        thick_contours, _ = cv2.findContours(contour_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        # obtain largest contour of the thickened lines (the border) and approximate a 4 sided polygon onto it
        # (a contour only becomes the candidate when its polygon approximation has exactly 4 vertices)
        max_area = 100000
        largest = -1
        max_rect = None
        for i, contour in enumerate(thick_contours):
            area = cv2.contourArea(contour)
            if area > max_area:
                epsilon = 0.05 * cv2.arcLength(contour, True)
                rect = cv2.approxPolyDP(contour, epsilon, True) # uses Douglas-Peucker algorithm (probably overkill)
                if (len(rect) == 4):
                    max_area = area
                    largest = i
                    max_rect = rect

        # perspective transform based on rectangle outline of board

        # NOTE(review): max_rect is still None when no contour qualified above; the reshape
        # below would then fail because None has no reshape method.
        corners = max_rect.reshape(-1, 2) # reshapes it so each row has 2 elements
        corners = [tuple(corner) for corner in corners] # convert to tuples

        corners_sorted = sort_square_grid_coords(corners, unpacked=True)

        # NOTE(review): tl, tr, bl and br are not assigned anywhere in the visible body and
        # corners_sorted is not read afterwards; presumably a line unpacking corners_sorted
        # into these four names is missing here. Confirm against the original file.
        src = np.float32([list(tl), list(tr), list(bl), list(br)])
        dest = np.float32([[0,0], [width, 0], [0, height], [width, height]])
        M = cv2.getPerspectiveTransform(src, dest)
        Minv = cv2.getPerspectiveTransform(dest, src)
        warped_img = img.copy()
        warped_img = cv2.warpPerspective(np.uint8(warped_img), M, (width, height))


        # perspective transform the intersections as well

        # same src/dest as above, so M and Minv are recomputed with identical values;
        # Minv is not read anywhere in the visible body
        M = cv2.getPerspectiveTransform(src, dest)
        Minv = cv2.getPerspectiveTransform(dest, src)

        warped_ip = intersection.copy() # warped intersection points

        warped_ip = cv2.warpPerspective(np.uint8(warped_ip), M, (width, height))


        warped_intersections, hierarchy = cv2.findContours(warped_ip, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)

        # find midpoints of warped intersection contours (this gives us exact coordinates of the corners of the grid)
        warped_intersection_points = []
        for contour in warped_intersections:

            # M is rebound to the moments dict here; the perspective matrix is no longer needed
            M = cv2.moments(contour)
            if (M["m00"] != 0):
                midpoint_x = int(M["m10"] / M["m00"])
                midpoint_y = int(M["m01"] / M["m00"])

                warped_intersection_points.append((midpoint_x, midpoint_y))


        # sort the coordinates from left to right then top to bottom

        sorted_warped_points = sort_square_grid_coords(warped_intersection_points, unpacked=False)

        # debug view only: all thickened contours, the chosen border contour and its 4-vertex polygon
        if (show_cv):
            contours_img = img.copy()
            cv2.drawContours(contours_img, thick_contours, -1, (0, 255, 0), 2)
            cv2.drawContours(contours_img, [thick_contours[largest]], -1, (0, 0, 255), 2)
            cv2.drawContours(contours_img, [max_rect], -1, (255, 0, 0), 2)

                # NOTE(review): this line is indented one level deeper than its neighbours and no
                # loop header is visible; x and y are not assigned in the visible body.
                cv2.circle(contours_img, (x, y), 5, (0, 255, 255), -1)

    zalonzo2's avatar
    zalonzo2 committed
            display_img([contours_img])

        return warped_img, sorted_warped_points
    
    # Walk the 8x8 squares of the warped board, estimate the dominant hue of the largest
    # coloured blob in each square, map it to a colour name and OCR a piece letter from it.
    #   warped_img: BGR image (it is converted with COLOR_BGR2HSV / COLOR_BGR2GRAY below)
    #   sorted_warped_points: indexed as [row][col] with indices 0..8, each entry an (x, y) point
    # NOTE(review): no return statement is visible in this function, and several names used
    # below (color_grid, pixel_thresh) are never assigned in the visible part of the file.
    def find_pieces(warped_img, sorted_warped_points):

        hsv_img = cv2.cvtColor(warped_img, cv2.COLOR_BGR2HSV)

        # gray_img is not read again in the visible body
        gray_img = cv2.cvtColor(warped_img, cv2.COLOR_BGR2GRAY)

        # threshold to find strongest colors in image

        saturation_thresh = 60
        brightness_thresh = 80
        hsv_mask_sat = cv2.inRange(hsv_img[:,:,1], saturation_thresh, 255) # saturation mask
        hsv_mask_bright = cv2.inRange(hsv_img[:,:,2], brightness_thresh, 255) # brightness mask

        # a pixel survives only when it passes both the saturation and the brightness threshold
        hsv_mask = cv2.bitwise_and(hsv_mask_sat, hsv_mask_bright)


        hsv_after = cv2.bitwise_and(hsv_img, hsv_img, mask=hsv_mask)

        bgr_after = cv2.cvtColor(hsv_after, cv2.COLOR_HSV2BGR)

        gray_after = cv2.cvtColor(bgr_after, cv2.COLOR_BGR2GRAY)

    zalonzo2's avatar
    zalonzo2 committed
        # define color thresholds to use to classify colors later on

        # each entry is a half-open hue interval: lower <= avg_hue < upper (see the lookup below)
        # NOTE(review): OpenCV documents 8-bit hue as 0-179, so the part of the 'red' range
        # above 179 can never match, and hues near 0 fall in no range. Confirm intent.
        hue_thresh_dict = {'red': (170,190), 'orange':(8,18), 'yellow': (18,44), 'green': (50,70), 'purple': (120,140), 
                           'teal': (80,105), 'pink': (140,170)} # CHANGE

        # a PIL drawing surface is only prepared when debug output is enabled
        if (show_cv):
            warped_img_pil = cv2_to_pil(warped_img)
            warped_img_draw = ImageDraw.Draw(warped_img_pil)


        # a new OCR reader is constructed on every call to find_pieces
        reader = easyocr.Reader(['en'], gpu=False, verbose=False)

        filled_contour_mask = np.zeros_like(hsv_after)

        # loop through each square of chess board

        # NOTE(review): color_grid is appended to here but never created in the visible code.
        for i in range(0,8):
            color_grid.append([])
            for j in range(0,8):

                # establish corners of current square
                tl = sorted_warped_points[i][j]
                tr = sorted_warped_points[i][j+1]
                bl = sorted_warped_points[i+1][j]
                br = sorted_warped_points[i+1][j+1]


                # # create a polygon mask for current grid square 
                # # (this is because square is not always perfectly square so I can't just loop pixels from tl to tr then bl to br)
                # # (this might not be worth the extra computation required to shave off a few pixels from being considered)
                # height, width, _ = warped_img.shape
                # rect_mask = np.zeros((height, width), dtype=np.uint8)
                # poly = np.array([[tl, tr, br, bl]], dtype=np.int32)
                # cv2.fillPoly(rect_mask, poly, 255)
                # num_pixels = 1
                # hue = 0

                # # loop through a perfect square that is slightly bigger than actual "square." obtain average hue of color in grid square
                # for x in range(min(tl[0],bl[0]), max(tr[0],br[0])):
                #     for y in range(min(tl[1],tr[1]), max(bl[1],br[1])):
                #         if rect_mask[y,x] == 255 and hsv_after[y, x, 0] != 0:
                #             num_pixels += 1
                #             hue += hsv_after[y, x, 0]
                # avg_hue = hue / num_pixels


                height, width, _ = warped_img.shape
                rect_mask = np.zeros((height, width), dtype=np.uint8)

                poly = np.array([[tl, tr, br, bl]], dtype=np.int32)


                # NOTE(review): in the visible code rect_mask is still all zeros here (poly is built
                # but never drawn into it), so this masking would blank the whole image; presumably
                # a fillPoly call like the one in the commented-out block above is missing.
                masked_hsv_after = cv2.bitwise_and(gray_after, gray_after, mask=rect_mask)
                # display_img([masked_hsv_after])
                contours, hierarchy = cv2.findContours(masked_hsv_after, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
                # print(contours)


                if contours is not None:
                    # NOTE(review): no except/finally clause is visible for this try, and
                    # largest_contour is reset to None right after being computed; presumably an
                    # `except` line (max() raises ValueError on an empty sequence) is missing
                    # before the reset.
                    try:
                        largest_contour = max(contours, key=cv2.contourArea)

                        # print(cv2.contourArea(largest_contour))
                        # if cv2.contourArea(largest_contour) < 50:
                        #     largest_contour = None

                        largest_contour = None
                    
                    if largest_contour is not None:
                        # fill the blob into channel 0 of the shared mask so its pixels can be tested below
                        cv2.drawContours(filled_contour_mask, [largest_contour], -1, (255, 0, 0), thickness=cv2.FILLED)

                        cur_bounding_box = cv2.boundingRect(largest_contour)

                        # cv2.drawContours(warped_img, largest_contour, -1, (255, 255, 0), 2)
                        # display_img([filled_contour_mask])
                        # num_pixels starts at 1, so the division below never divides by zero
                        # (and the average is taken over one more than the matching pixel count)
                        num_pixels = 1
                        hue = 0
                        # scan the axis-aligned box around the square; count pixels inside the filled
                        # blob whose hue is non-zero
                        for x in range(min(tl[0],bl[0]), max(tr[0],br[0])):
                            for y in range(min(tl[1],tr[1]), max(bl[1],br[1])):
                                if filled_contour_mask[y, x, 0] > 0 and hsv_after[y, x, 0] != 0:
                                    num_pixels += 1
                                    hue += hsv_after[y, x, 0]
                        avg_hue = hue / num_pixels

                        # blobs with too few coloured pixels are erased from the mask again
                        # NOTE(review): pixel_thresh is not defined in the visible code.
                        if num_pixels < pixel_thresh:
                            cv2.drawContours(filled_contour_mask, [largest_contour], -1, (0, 0, 0), thickness=cv2.FILLED)

                        # for pixel in largest_contour:
                        #     y,x = pixel[0]
                        #     # display_img([hsv_after])
                        #     if hsv_after[y, x, 0] != 0:
                        #             num_pixels += 1
                        #             hue += hsv_after[y, x, 0]
                        # avg_hue = hue / num_pixels


                # if there is a color in square, then label based on custom hue thresholds and add to color_grid

                if largest_contour is not None:
                    if avg_hue != 0:
                        for color, (lower, upper) in hue_thresh_dict.items():
                            if lower <= avg_hue < upper:
                                x, y, w, h = cur_bounding_box
                                border = 10
                                # bl and tr are reused here as (row, col) crop corners: the bounding box
                                # grown by `border` pixels, clamped to 0 and img_size
                                bl = (max(y-border,0),max(x-border,0))
                                tr = (min(y+h+border,img_size),min(x+w+border,img_size))
                                img_to_read = masked_hsv_after[bl[0]:tr[0], bl[1]:tr[1]]
                                # binarise the crop: every non-zero pixel becomes 255
                                img_to_read[img_to_read != 0] = 255 
                                result = reader.readtext(
                                    image=img_to_read,
                                    allowlist="PKQRBN",
                                    rotation_info=[180],

                                    min_size = 5
                                )
                                # NOTE(review): letter is only assigned when OCR returns a result; otherwise
                                # the append below uses the value left over from an earlier square (or
                                # fails if none exists yet).
                                if len(result) != 0:
                                    bound_box, letter, confidence = result[0]

                                color_grid[i].append([color, avg_hue, num_pixels, letter])
                                # piece_found is not read anywhere in the visible body
                                piece_found = True
                                if show_cv and num_pixels > pixel_thresh:
                                    y,x = tl[0] + 5, tl[1] + 5
                                    warped_img_draw.text((y,x), color, fill=(255, 0, 0)) # draw color found onto image

                    # NOTE(review): as indented, this empty entry is appended for every square that has
                    # a contour, in addition to any colour entry appended above; presumably an `else:`
                    # line is missing. Confirm against the original file.
                    color_grid[i].append([None, avg_hue, num_pixels, None])

            # NOTE(review): as indented, the block below sits inside the `for i` loop and reads
            # warped_img_draw, which only exists when show_cv is set; presumably an `if show_cv:`
            # header at function level is missing.
            warped_img_draw = pil_to_cv2(warped_img_draw._image)

            bgr_after_intersections = bgr_after.copy()
            for points in sorted_warped_points:

                        # NOTE(review): no inner loop header is visible; `point` is not assigned here.
                        cv2.circle(bgr_after_intersections, point, 1, (255, 255, 255), -1)

            display_img([warped_img_draw, bgr_after_intersections, filled_contour_mask])
    
        # reader = easyocr.Reader(['en'], gpu=False)
        # results = reader.readtext(
        #     image=warped_img,
        #     allowlist="PKQRBNpkqrbn",
        #     # rotation_info=[180],
        #     # batch_size=2,
        #     # text_threshold=0.01,
        #     # slope_ths=0.0
        # )
        # print(len(results))
        # for text, _, _ in results:
        #     print(text)
    
    
        # gray_masked = cv2.bitwise_and(gray_after, gray_after, mask=filled_contour_mask[:,:,0])
        # gray_masked[gray_masked != 0] = 255
        # scale = 1
        # size = img_size * scale
        # img_to_read = cv2.resize(gray_masked, (int(size), int(size)))
        # reader = easyocr.Reader(['en'])
        # results = reader.readtext(
        #     image=img_to_read,
        #     allowlist="PKQRBN",
        #     rotation_info=[180],
        #     batch_size=500,
        #     text_threshold=0.3,
        #     link_threshold = 100000000000,
        #     slope_ths=0.0,
        #     ycenter_ths=0.01,
        #     height_ths=0.01,
        #     width_ths = 0.01,
        #     min_size=int(size/150)
        # )
    
        # img_to_read = cv2.cvtColor(img_to_read, cv2.COLOR_GRAY2BGR)
        # img_to_read = cv2.resize(img_to_read, (img_size, img_size))
    
        # # if show_cv:
        # img_to_read_pil = cv2_to_pil(img_to_read)
        # img_to_draw = ImageDraw.Draw(img_to_read_pil)
    
        # for result in results:
        #     bound_box, letter = result[0:2]
        #     y_min, x_min = [int(min(val)/scale) for val in zip(*bound_box)]
        #     y_max, x_max = [int(max(val)/scale) for val in zip(*bound_box)]
        #     # cv2.circle(img_to_read, (int(x_min + (x_max - x_min)/2), int(y_min + (y_max - y_min)/2)), 1, (0, 255, 0), 2)
        #     # cv2.circle(img_to_read, (x_min, y_min), 1, (0, 255, 0), 2)
        #     # cv2.putText(img_to_read, letter, (x_min, y_min), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)
    
        #     y_center, x_center = (int(y_min + (y_Fmax - y_min)/2), int(x_min + (x_max - x_min)/2))
    
    
        #     x_grid = int(x_center / img_size * 8)
        #     y_grid = int(y_center / img_size * 8)
        #     # print(y_grid, x_grid, letter)
        #     color_grid[x_grid][y_grid][3] = letter
    
        #     # if show_cv:
        #     img_to_draw.text((y_min - 10, x_min), letter, fill=(255, 0, 0))
    
        # print color_grid. only print when the color is found a lot in the square (> pixel_thresh times)
    
        # if show_cv:
        #     # print("|avg_hue, num_pixels, letter|")
        #     for row in color_grid:
        #         print("||", end="")
        #         for color, avg_hue, num_pixels, letter in row:
        #             if num_pixels > pixel_thresh:
        #                 print(f"{int(avg_hue)},   {letter}\t|", end="")
        #                 # print(f"{color},   {letter}\t|", end="")
        #             else:
        #                 print("\t\t|", end="")
        #         print("|")
    
    
        # img_to_draw._image.save('game_images/ocr_results.jpg') 
        # if show_cv:
        #     img_to_draw = pil_to_cv2(img_to_draw._image)
        #     display_img([img_to_draw])
    
    def sort_square_grid_coords(coordinates, unpacked):
        """Order grid points top-to-bottom by row and left-to-right within a row.

        coordinates: sequence of (x, y) points. The count is assumed to be a
            perfect square (e.g. 81 points for a 9x9 grid); rows are formed by
            sorting on y and cutting the result into chunks of sqrt(count).
        unpacked: when falsy, return a list of rows (each a list of points);
            when truthy, return one flat list in the same order.

        Returns an empty list for empty input. The input sequence is not modified.
        """
        if len(coordinates) == 0:
            # sqrt_len would be 0 and range() rejects a zero step
            return []

        sqrt_len = int(math.sqrt(len(coordinates)))
        sorted_coords = sorted(coordinates, key=lambda coord: coord[1])  # first sort by y values

        # then group rows of the square (for example, 9x9 grid would be 81 coordinates so split into 9 arrays of 9)
        rows = [sorted_coords[i:i+sqrt_len] for i in range(0, len(sorted_coords), sqrt_len)]
        for row in rows:
            row.sort(key=lambda coord: coord[0])  # now sort each row by x

        if not unpacked:
            # callers index the result as [row][col]
            return rows

        # unpack/collapse groups to just be an array of tuples
        return [coord for row in rows for coord in row]