from pathlib import Path
from skimage import io
from torch.utils.data import Dataset
import numpy as np
from torchvision.ops import masks_to_boxes
import torch
import torchvision.transforms.v2 as T

from torchvision import tv_tensors


def get_image_names(folder):
    """Return every PNG below ``images/<folder>``, relative to that directory.

    The search is recursive, so images in nested sub-folders are included.
    Paths are resolved against the current working directory.

    Args:
        folder: Sub-directory of ``images`` to scan (str or Path).

    Returns:
        A sorted list of ``Path`` objects relative to ``images/<folder>``.
    """
    start = Path("images") / Path(folder)
    # sorted() consumes the generator directly; no index or temporary list needed.
    return sorted(path.relative_to(start) for path in start.rglob("*.png"))


def get_masks(indices, annotation):
    """Build one boolean mask per object label.

    Args:
        indices: 1-D sequence of label values (torch tensor, numpy array or
            list). Mask ``i`` selects the pixels equal to ``indices[i]``.
        annotation: 2-D label image (array-like) of shape (H, W).

    Returns:
        A ``torch.bool`` tensor of shape (len(indices), H, W). If ``indices`` is
        empty, the result has shape (0, H, W).
    """
    labels = np.asarray(indices).reshape(-1, 1, 1)
    label_image = np.asarray(annotation)
    # A single broadcast comparison done entirely in numpy. This avoids
    # comparing an ndarray with 0-d torch tensors (which depends on numpy/torch
    # operator deferral) and avoids a Python loop over objects.
    masks = label_image[np.newaxis, :, :] == labels
    return torch.from_numpy(masks)


def get_sequence_images(sequence_path):
    """Return the sorted PNG paths found recursively under ``images/<sequence_path>``.

    The returned paths include the ``images`` prefix, i.e. they can be opened
    directly from the current working directory.
    """
    # sorted() already returns a new list; the old copying comprehension was redundant.
    return sorted((Path("images") / Path(sequence_path)).rglob("*.png"))


def get_sequences(folder):
    """List the sequence directories directly inside ``images/<folder>``.

    Only immediate sub-directories are returned (files are skipped). They are
    sorted by path and expressed relative to ``images``, so each result can be
    passed straight back to helpers that prepend ``images`` themselves.
    """
    images_root = Path("images")
    return [
        entry.relative_to(images_root)
        for entry in sorted((images_root / Path(folder)).iterdir())
        if entry.is_dir()
    ]


def get_image_tensor(image, normalize):
    """Convert an image array into a float32 torchvision image tensor.

    ``ToDtype(..., scale=True)`` rescales integer inputs to [0, 1].

    Args:
        image: Image array as read from disk.
        normalize: If true, additionally normalise the channels with the
            mean/std used by the pretrained backbone (the usual ImageNet
            statistics).

    Returns:
        The transformed tensor.
    """
    steps = [
        T.ToImage(),
        T.ToDtype(torch.float32, scale=True),
    ]
    if normalize:
        steps.append(
            T.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225])
        )
    return T.Compose(steps)(image)


def remove_empty_boxes(boxes):
    """Keep only boxes with strictly positive width and height.

    Args:
        boxes: Tensor of shape (N, 4) in (x1, y1, x2, y2) order.

    Returns:
        The rows of ``boxes`` where ``x1 < x2`` and ``y1 < y2``.
    """
    x_min, y_min, x_max, y_max = boxes.unbind(dim=1)
    has_area = (x_max > x_min).logical_and(y_max > y_min)
    return boxes[has_area]


def create_lists(list_of_tuples):
    """Transpose a batch of per-sample tuples into one list per field.

    Intended as a ``collate_fn`` for a DataLoader. Samples such as
    ``(image, target)`` or ``(image, target, name)`` have a variable-sized
    content that cannot be stacked, so instead of stacking the batch is
    regrouped into ``([image, ...], [target, ...], ...)``.

    Args:
        list_of_tuples: Sequence of equally sized tuples, one per sample.

    Returns:
        A tuple containing one list per tuple position. An empty batch
        yields ``()``.
    """
    # zip(*batch) transposes samples -> fields; iterate the groups directly.
    return tuple(list(field) for field in zip(*list_of_tuples))


# Label values in the annotation images that do not denote objects.
_BACKGROUND_LABEL = 0
_IGNORE_LABEL = 10000


def _object_labels(annotation):
    """Return the sorted object labels present in ``annotation`` as an int64 tensor.

    The background and ignore labels are excluded by value. This keeps the
    result correct even when an image has no ignore region or no background.
    """
    labels = np.unique(annotation)  # np.unique already returns sorted values
    labels = labels[(labels != _BACKGROUND_LABEL) & (labels != _IGNORE_LABEL)]
    return torch.as_tensor(labels.astype(np.int64))


def _drop_degenerate(boxes, *aligned):
    """Remove zero-width/height boxes together with the matching rows of ``aligned``.

    Every tensor in ``aligned`` must have one entry per box along dimension 0.
    Filtering them all with the same mask keeps boxes, masks and labels paired.
    Returns the filtered boxes followed by each filtered aligned tensor.
    """
    keep = (boxes[:, 0] < boxes[:, 2]) & (boxes[:, 1] < boxes[:, 3])
    return (boxes[keep],) + tuple(tensor[keep] for tensor in aligned)


def prepare_images(image, annotation, normalize, augment, device, classes_from_zero, index):
    """Convert an image and its label image into a model input and target dict.

    Conventions for ``annotation`` (the label image):
    - each object has its own label value;
    - the object's class is ``label // 1000``;
    - ``0`` is background and ``10000`` marks regions to ignore.

    Args:
        image: Image array read from disk. It is not modified; ignore regions
            are blanked in a private copy.
        annotation: Label image with the same spatial size as ``image``
            (assumed — not checked here).
        normalize: Forwarded to ``get_image_tensor``.
        augment: If true, apply random flip, photometric distortion and affine
            augmentation jointly to the image, boxes and masks.
        device: Device that the image and all target tensors are moved to.
        classes_from_zero: If true, shift class ids down by one.
        index: Stored as ``"image_id"`` in the target.

    Returns:
        ``(image_tensor, ground_truth)``. ``ground_truth`` holds the keys
        ``boxes``, ``masks``, ``labels``, ``indices`` and ``image_id``. The
        ``boxes``, ``masks``, ``labels`` and ``indices`` entries always have
        the same length.
    """
    # Copy first so the caller's array (returned as the "unprocessed" image by
    # SequentialDataset) keeps its original pixels. Ignore regions are blanked
    # out; filling them with random noise would be an alternative.
    image = np.array(image, copy=True)
    image[annotation == _IGNORE_LABEL] = 0

    # Same kind of normalization as the pretrained model.
    image = get_image_tensor(image, normalize)

    # Object labels, excluding background and ignore by value rather than by
    # position in the sorted unique values.
    indices = _object_labels(annotation)

    # Class of each detection is encoded in the thousands of its label.
    classes = indices // 1000

    # The transformer model needs class ids indexed from zero.
    if classes_from_zero:
        classes = classes - 1

    # One binary mask per object, then the bounding box surrounding each mask.
    masks = get_masks(indices, annotation)
    boxes = masks_to_boxes(masks)

    # Zero-width/height boxes make training fail. Drop them together with
    # their masks, labels and ids so that every target entry stays aligned.
    boxes, masks, classes, indices = _drop_degenerate(boxes, masks, classes, indices)

    # https://pytorch.org/vision/stable/auto_examples/transforms/plot_transforms_getting_started.html
    if augment:
        boxes = tv_tensors.BoundingBoxes(
            boxes, format="XYXY", canvas_size=image.shape[-2:])

        masks = tv_tensors.Mask(masks)

        augmentations = T.Compose([
            T.RandomHorizontalFlip(p=0.5),
            T.RandomPhotometricDistort(
                brightness=(0.75, 1.25), hue=(-0.15, 0.15)),
            T.RandomAffine(degrees=(-12, 12),
                           translate=(0.35, 0.35), scale=(0.5, 1.5))
        ])

        image, boxes, masks = augmentations(image, boxes, masks)

        # Augmentations (e.g. translating an object off the canvas) can create
        # new degenerate boxes; filter every aligned tensor again.
        boxes, masks, classes, indices = _drop_degenerate(boxes, masks, classes, indices)

    # Format expected by Mask R-CNN.
    ground_truth = {
        "boxes": boxes.to(device),
        "masks": masks.to(device),
        "labels": classes.to(device),
        "indices": indices.to(device),
        "image_id": index
    }

    return image.to(device), ground_truth


class DetectionDataset(Dataset):
    """Map-style dataset of individual annotated frames for detection training.

    For every PNG found (recursively) under ``images/<folder>``, the matching
    label image is read from the same relative path under
    ``annotations/<folder>``. Both roots are relative to the working directory.
    Each item is whatever ``prepare_images`` returns for that pair.
    """

    def __init__(self, folder, normalize, augment, classes_from_zero, dev):
        super().__init__()
        self.folder = Path(folder)
        # Sorted image paths, relative to images/<folder>.
        self.paths = get_image_names(folder)
        # Options forwarded unchanged to prepare_images.
        self.dev = dev
        self.normalize = normalize
        self.augment = augment
        self.classes_from_zero = classes_from_zero

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        relative = Path(self.paths[index])
        # Image first, then its annotation, from the two parallel roots.
        image, annotation = (
            io.imread(Path(root) / self.folder / relative)
            for root in ("images/", "annotations/")
        )
        return prepare_images(
            image,
            annotation,
            self.normalize,
            self.augment,
            self.dev,
            self.classes_from_zero,
            index,
        )


class SequentialDataset(Dataset):
    """All frames of one sequence, each paired with its annotation, for inference.

    Besides the model input and target produced by ``prepare_images``, every
    item also carries the raw image array read from disk and that image's
    path. The raw array is the same object handed to ``prepare_images``.
    """

    def __init__(self, sequence_folder, normalize, augment, classes_from_zero, dev):
        super().__init__()
        self.sequence_folder = Path(sequence_folder)
        # Sorted frame paths, relative to images/<sequence_folder>.
        self.image_names = get_image_names(sequence_folder)
        # Options forwarded unchanged to prepare_images.
        self.dev = dev
        self.augment = augment
        self.normalize = normalize
        self.classes_from_zero = classes_from_zero

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, index):
        relative = Path(self.image_names[index])
        image_path, annotation_path = (
            Path(root) / self.sequence_folder / relative
            for root in ("images/", "annotations/")
        )

        raw_image = io.imread(image_path)
        annotation = io.imread(annotation_path)

        image_tensor, ground_truth = prepare_images(
            raw_image,
            annotation,
            self.normalize,
            self.augment,
            self.dev,
            self.classes_from_zero,
            index,
        )
        return image_tensor, ground_truth, raw_image, image_path
