import logging
import tarfile
from pathlib import Path
from typing import Union

import imgaug.augmenters as iaa
import numpy as np
import torch
import torchvision.transforms as transforms
from imgaug.augmentables.bbs import BoundingBox, BoundingBoxesOnImage
from PIL import Image
from torch.utils.data import Dataset

# Module-level logger; the dataset constructor logs how many images it found.
logger = logging.getLogger(__name__)
# Anything accepted where a filesystem path is expected (passed through Path()).
PathLike = Union[Path, str]


class CompressedATRDataset(Dataset):
    """Dataset of PNG images and per-image box files packed in a tar archive.

    The archive is extracted once at construction time. Every ``*.png`` found
    (recursively, in sorted order) is paired with the ``.txt`` file of the same
    stem. That file holds rows of five numbers whose columns 1-4 are divided by
    the image width/height. Presumably the rows are ``label cx cy w h`` in pixels
    (inferred from ``DefaultTransforms``; confirm against the data producer).

    Each item is ``(image_tensor, boxes_tensor, image_path_posix)``, where
    ``boxes_tensor`` is a float32 tensor padded to ``n_boxes`` rows. Padding
    rows are all zeros except for a label of ``-1``.
    """

    # Expected PIL ``Image.size`` (width, height) of every source image.
    loaded_size = 640, 480
    # Square size after padding; not referenced inside this class.
    output_size = 640, 640

    def __init__(
        self,
        tar_dataset: PathLike,
        transform=None,
        n_boxes: int = 10,
        extract_to: PathLike = None,
    ):
        """Extract ``tar_dataset`` and index the PNG images it contains.

        Args:
            tar_dataset: path to the (optionally compressed) tar archive.
            transform: optional callable ``(image_array, boxes) -> (image_array,
                boxes)`` applied in ``__getitem__``.
            n_boxes: number of box rows every item is padded to; files with more
                boxes raise ``ValueError``.
            extract_to: directory to extract into. When ``None`` a fresh
                temporary directory is created and removed again when this
                dataset is garbage collected (or at interpreter exit).
        """
        import os
        import weakref
        from tempfile import mkdtemp

        self.tar_dataset = Path(tar_dataset)
        if extract_to is None:
            temp_name = mkdtemp()
            # The finalizer holds no reference to ``self``, so the dataset stays
            # collectable and picklable (e.g. for DataLoader workers).
            weakref.finalize(
                self, self._remove_dir_if_owner, temp_name, os.getpid()
            )
            self.temp_dir = Path(temp_name)
        else:
            self.temp_dir = Path(extract_to)

        with tarfile.open(self.tar_dataset) as tar:
            # The "data" filter (Python 3.12+, backported to security releases)
            # rejects members escaping the target dir, absolute paths, etc.
            if hasattr(tarfile, "data_filter"):
                tar.extractall(path=self.temp_dir, filter="data")
            else:
                tar.extractall(path=self.temp_dir)

        # Sorted so that indices are stable across filesystems and runs.
        self.image_files = sorted(self.temp_dir.glob("**/*.png"))
        logger.info("Loaded %d images", len(self.image_files))
        self.transform = transform
        self.n_ret_boxes = n_boxes

    @staticmethod
    def _remove_dir_if_owner(path: str, owner_pid: int) -> None:
        """Delete ``path`` recursively, but only from the process that made it."""
        import os
        import shutil

        # Forked DataLoader workers inherit the finalizer registry. The pid guard
        # stops a worker from deleting data the parent process still uses.
        if os.getpid() == owner_pid:
            shutil.rmtree(path, ignore_errors=True)

    def __getitem__(self, index):
        """Return ``(image_tensor, boxes_tensor, image_path_posix)`` for ``index``.

        Raises:
            ValueError: if the image size differs from ``loaded_size`` or its
                box file has more than ``n_boxes`` rows.
        """
        image_path = self.image_files[index]
        # Close the file handle deterministically; convert() returns a copy.
        with Image.open(image_path) as source:
            image = source.convert("RGB")
        if image.size != self.loaded_size:
            raise ValueError(
                f"{image_path}: expected image size {self.loaded_size}, "
                f"got {image.size}"
            )
        boxes_tensor = self._read_boxes(image_path.with_suffix(".txt"), image)
        image_np = np.array(image)
        if self.transform is not None:
            image_np, boxes = self.transform(image_np, boxes_tensor)
            # Keep the return type identical to the no-transform path.
            boxes_tensor = torch.as_tensor(boxes, dtype=torch.float)
        image_tensor = transforms.ToTensor()(image_np)
        return image_tensor, boxes_tensor, image_path.as_posix()

    def _read_boxes(self, path: Path, image: Image.Image):
        """Load boxes from ``path``, normalize them and pad to ``n_ret_boxes`` rows.

        Columns 1 and 3 are divided by the image width and columns 2 and 4 by
        its height. Padding rows are zeros with column 0 set to ``-1``.
        """
        # reshape handles both a single-row file (1-D result) and an empty file.
        boxes = np.loadtxt(path).reshape(-1, 5)
        boxes[:, [1, 3]] /= image.width
        boxes[:, [2, 4]] /= image.height
        max_boxes = self.n_ret_boxes
        if boxes.shape[0] > max_boxes:
            raise ValueError(
                f"{path}: {boxes.shape[0]} boxes exceed the limit of {max_boxes}"
            )
        padding = np.zeros((max_boxes - boxes.shape[0], 5), dtype=float)
        padding[:, 0] = -1
        boxes = np.concatenate((boxes, padding), axis=0)
        return torch.tensor(boxes, dtype=torch.float)

    def __len__(self):
        """Return the number of PNG images found in the archive."""
        return len(self.image_files)


class DefaultTransforms:
    """Pad an image to a square (centered) and move its boxes to match.

    Called as ``transform(img, boxes)``:

    - ``img`` is an image array whose first two axes are height and width.
    - ``boxes`` is an array-like of rows ``label cx cy w h``.

    It returns the padded image and a float64 ``(M, 5)`` array in the same
    layout. Rows keep their label, so padding rows (label ``-1``) pass through.

    imgaug works in pixel coordinates. With ``normalized=True`` (the default,
    matching ``CompressedATRDataset``, which divides coordinates by the image
    size), incoming boxes are scaled to pixels before padding. Outgoing boxes
    are divided by the *padded* image size, so they are again fractions in
    [0, 1], now relative to the square image.
    """

    def __init__(self, normalized: bool = True):
        """
        Args:
            normalized: if True, input boxes are fractions of the input image
                size and output boxes are fractions of the padded image size.
                If False, both input and output boxes are in pixels.
        """
        self.normalized = normalized
        self.augmentations = iaa.PadToAspectRatio(
            1.0, position="center-center"
        ).to_deterministic()

    def __call__(self, img, boxes):
        """Return ``(padded_img, boxes)`` with boxes adjusted for the padding."""
        boxes = np.array(boxes, dtype=float).reshape(-1, 5)

        if self.normalized:
            # Relative -> absolute pixels of the unpadded input image.
            in_h, in_w = img.shape[:2]
            boxes[:, [1, 3]] *= in_w
            boxes[:, [2, 4]] *= in_h

        # Convert xywh to xyxy, which is what imgaug's BoundingBox expects.
        boxes[:, 1:] = xywh2xyxy_np(boxes[:, 1:])

        bounding_boxes = BoundingBoxesOnImage(
            [BoundingBox(*box[1:], label=box[0]) for box in boxes], shape=img.shape
        )

        img, bounding_boxes = self.augmentations(
            image=img, bounding_boxes=bounding_boxes
        )

        # Drop boxes fully outside the image and clip the rest to its borders.
        bounding_boxes = bounding_boxes.clip_out_of_image()

        # Back to numpy rows of (label, cx, cy, w, h) in padded-image pixels.
        out = np.zeros((len(bounding_boxes), 5))
        for box_idx, box in enumerate(bounding_boxes):
            out[box_idx, 0] = box.label
            out[box_idx, 1] = (box.x1 + box.x2) / 2
            out[box_idx, 2] = (box.y1 + box.y2) / 2
            out[box_idx, 3] = box.x2 - box.x1
            out[box_idx, 4] = box.y2 - box.y1

        if self.normalized:
            # Absolute pixels -> fractions of the padded (square) image.
            out_h, out_w = img.shape[:2]
            out[:, [1, 3]] /= out_w
            out[:, [2, 4]] /= out_h

        return img, out


def xywh2xyxy_np(x: np.ndarray) -> np.ndarray:
    """Convert boxes from center form ``(cx, cy, w, h)`` to corners ``(x1, y1, x2, y2)``.

    Works on the last axis, so any leading dimensions are kept. Only the first
    four entries of that axis are read. The result has the same shape as ``x``,
    and any further entries are zero.

    Array-likes (e.g. nested lists) are accepted. Integer or boolean input is
    promoted to float64 so the half widths/heights are not truncated, while
    floating input keeps its dtype.
    """
    x = np.asarray(x)
    dtype = x.dtype if np.issubdtype(x.dtype, np.floating) else np.float64
    y = np.zeros(x.shape, dtype=dtype)
    half_w = x[..., 2] / 2
    half_h = x[..., 3] / 2
    y[..., 0] = x[..., 0] - half_w
    y[..., 1] = x[..., 1] - half_h
    y[..., 2] = x[..., 0] + half_w
    y[..., 3] = x[..., 1] + half_h
    return y


def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
    """Map normalized ``(cx, cy, w, h)`` rows to pixel corners ``(x1, y1, x2, y2)``.

    ``x`` is an ``(N, >=4)`` numpy array or torch tensor; the input is not
    modified. Columns beyond the fourth are copied through unchanged.
    ``w``/``h`` scale the x/y coordinates and ``padw``/``padh`` are added
    afterwards (e.g. letterbox offsets).
    """
    is_tensor = isinstance(x, torch.Tensor)
    result = x.clone() if is_tensor else np.copy(x)

    center_x, center_y = x[:, 0], x[:, 1]
    half_w, half_h = x[:, 2] / 2, x[:, 3] / 2

    result[:, 0] = w * (center_x - half_w) + padw  # left
    result[:, 1] = h * (center_y - half_h) + padh  # top
    result[:, 2] = w * (center_x + half_w) + padw  # right
    result[:, 3] = h * (center_y + half_h) + padh  # bottom
    return result