import os
import glob
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image, ImageEnhance
import pytorch_lightning as pl
import random
from sklearn.model_selection import train_test_split


def load_datasets(root_dir, random_seed=42, val_size=5):
    """
    Load training, validation, and test image paths from the DRIVE dataset structure.

    Training images are read from ``<root_dir>/training/images``, their masks are
    expected in ``<root_dir>/training/1st_manual`` and test images are read from
    ``<root_dir>/test/images``. Only ``*.tif`` files are considered.

    Args:
        root_dir (str): Root directory of the DRIVE dataset.
        random_seed (int): Seed for reproducibility of train/val split. It is also
            used to seed Python's global ``random`` module as a side effect.
        val_size (int or float): Size of the validation split, forwarded to
            ``train_test_split`` as ``test_size``. Defaults to 5, the value that was
            previously hard-coded.

    Returns:
        tuple: Lists of paths for train_images, train_masks, val_images, val_masks, and test_images.
    """
    images_path = os.path.join(root_dir, 'training', 'images')
    masks_path = os.path.join(root_dir, 'training', '1st_manual')
    test_images_path = os.path.join(root_dir, 'test', 'images')

    # glob yields entries in file-system order, which differs between machines.
    # Sorting makes the seeded split below select the same files everywhere.
    image_names = sorted(
        os.path.basename(path)
        for path in glob.glob(os.path.join(images_path, '*.tif'))
    )
    test_names = sorted(
        os.path.basename(path)
        for path in glob.glob(os.path.join(test_images_path, '*.tif'))
    )

    images = [os.path.join(images_path, name) for name in image_names]
    # The mask name reuses the first three characters of the image name
    # (assumes DRIVE-style names such as '21_training.tif' -> '21_manual1.gif').
    masks = [os.path.join(masks_path, name[:3] + 'manual1.gif') for name in image_names]
    test_images = [os.path.join(test_images_path, name) for name in test_names]

    # Seeds the global RNG as well; code that draws from ``random`` afterwards
    # starts from this state.
    random.seed(random_seed)
    train_images, val_images, train_masks, val_masks = train_test_split(
        images, masks, test_size=val_size, random_state=random_seed)

    return train_images, train_masks, val_images, val_masks, test_images


class Data(Dataset):
    """
    Custom dataset for DRIVE segmentation tasks with on-the-fly data augmentation.

    Args:
        root_dir (str): Dataset root directory (stored on the instance only).
        train (bool): Whether to use training mode with augmentation.
        images (list): List of image file paths. ``None`` gives an empty dataset.
        masks (list): List of corresponding mask file paths (can be None for test).
        rotate (int): Max rotation angle for augmentation.
        flip (bool): Whether to apply horizontal flipping.
        random_crop (bool): Whether to apply random cropping.
        scale1 (int): Final image size (square).

    Returns:
        torch.Tensor: Transformed image, or an ``(image, mask)`` tuple when masks
        are available.
    """
    def __init__(self,
                 root_dir,
                 train=True,
                 images=None,
                 masks=None,
                 rotate=40,
                 flip=True,
                 random_crop=True,
                 scale1=512):

        self.root_dir = root_dir
        self.train = train
        self.rotate = rotate
        self.flip = flip
        # Kept under a name different from the ``random_crop`` method: an instance
        # attribute called ``random_crop`` would shadow the method and turn
        # ``self.random_crop(...)`` into a call on a bool.
        self.use_random_crop = random_crop
        self.transform = transforms.ToTensor()
        self.resize = scale1
        self.images = [] if images is None else images
        self.groundtruth = masks

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.images)

    def random_crop(self, image, label, crop_size):
        """
        Apply the same random crop to image and label.

        Args:
            image: PIL image to crop.
            label: PIL mask, cropped with the same box as ``image``.
            crop_size: ``(width, height)`` of the crop window.

        Returns:
            tuple: ``(cropped_image, cropped_label)``.

        Raises:
            ValueError: From ``random.randint`` if the crop window is larger
                than the image.
        """
        crop_width, crop_height = crop_size
        w, h = image.size
        left = random.randint(0, w - crop_width)
        top = random.randint(0, h - crop_height)
        box = (left, top, left + crop_width, top + crop_height)
        return image.crop(box), label.crop(box)

    def random_enhance(self, image):
        """
        Apply random brightness, color, contrast, or sharpness adjustment.

        Exactly one of the four enhancers is picked at random and applied with
        a random factor.
        """
        # NOTE(review): the factor range includes values <= 0; per the Pillow docs
        # a Brightness factor of 0.0 already yields a black image - confirm that
        # this range is intended.
        value = random.uniform(-2, 2)
        choice = random.randint(1, 4)
        enhancers = (
            ImageEnhance.Brightness,
            ImageEnhance.Color,
            ImageEnhance.Contrast,
            ImageEnhance.Sharpness,
        )
        return enhancers[choice - 1](image).enhance(value)

    def rescale(self, img, re_size):
        """
        Center-crop and resize image to square shape.

        Args:
            img: PIL image of any aspect ratio.
            re_size (int): Side length of the returned square image.

        Returns:
            A new ``re_size`` x ``re_size`` image.
        """
        w, h = img.size
        side = min(w, h)
        left = (w - side) // 2
        top = (h - side) // 2
        square = img.crop((left, top, left + side, top + side))
        return square.resize((re_size, re_size))

    def __getitem__(self, idx):
        """
        Get preprocessed image (and mask if available) at a given index.

        Returns:
            The image tensor, or ``(image, label)`` when masks are available;
            the label is binarized with a 0.5 threshold.

        Raises:
            ValueError: If the dataset is in training mode but has no masks.
        """
        has_masks = bool(self.groundtruth)
        if self.train and not has_masks:
            # Augmentation transforms image and mask together; fail with a clear
            # message before touching any file.
            raise ValueError('Training mode requires masks, but none were provided.')

        # ``rescale`` returns new image objects, so the source files can be closed
        # as soon as it is done.
        with Image.open(self.images[idx]) as source:
            image = self.rescale(source, self.resize)

        label = None
        if has_masks:
            with Image.open(self.groundtruth[idx]) as source:
                label = self.rescale(source, self.resize)

        # No resize is needed for validation/test: ``rescale`` already produced a
        # ``self.resize`` x ``self.resize`` image.
        if self.train:
            # Apply augmentations; geometric ones use identical parameters for
            # image and mask so the two stay aligned.
            angle = random.randint(-self.rotate, self.rotate)
            image = image.rotate(angle)
            label = label.rotate(angle)

            if random.random() > 0.5:
                image = self.random_enhance(image)

            if self.use_random_crop:
                image, label = self.random_crop(image, label, crop_size=(self.resize, self.resize))

            if self.flip and random.random() > 0.5:
                image = image.transpose(Image.FLIP_LEFT_RIGHT)
                label = label.transpose(Image.FLIP_LEFT_RIGHT)

        image = self.transform(image)

        if label is None:
            return image

        label = self.transform(label)
        label = (label > 0.5).float()
        return image, label


class DriveDataset(pl.LightningDataModule):
    """
    Lightning DataModule wrapping the DRIVE train, validation and test sets.

    Args:
        root_dir (str): Root directory containing training and test folders.
        batch_size (int): Batch size used by every dataloader.
        num_workers (int): Number of subprocesses used for data loading.
    """
    def __init__(self, root_dir, batch_size=1, num_workers=4):
        super().__init__()
        self.root_dir = root_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

        # Populated by ``setup``.
        self.train_dataset = self.val_dataset = self.test_dataset = None

    def setup(self, stage=None):
        """
        Build the datasets required for ``stage``.

        ``'fit'`` creates the train and validation sets, ``'test'`` creates the
        test set, and ``None`` creates all three.
        """
        splits = load_datasets(self.root_dir)
        train_images, train_masks, val_images, val_masks, test_images = splits

        if stage in (None, 'fit'):
            self.train_dataset = Data(self.root_dir, train=True, images=train_images, masks=train_masks)
            self.val_dataset = Data(self.root_dir, train=False, images=val_images, masks=val_masks)

        if stage in (None, 'test'):
            self.test_dataset = Data(self.root_dir, train=False, images=test_images, masks=None)

    def _make_loader(self, dataset, shuffle):
        """Wrap ``dataset`` in a DataLoader using the module-wide settings."""
        return DataLoader(dataset,
                          batch_size=self.batch_size,
                          shuffle=shuffle,
                          num_workers=self.num_workers)

    def train_dataloader(self):
        """Return the shuffled training dataloader."""
        return self._make_loader(self.train_dataset, True)

    def val_dataloader(self):
        """Return the validation dataloader (no shuffling)."""
        return self._make_loader(self.val_dataset, False)

    def test_dataloader(self):
        """Return the test dataloader (no shuffling)."""
        return self._make_loader(self.test_dataset, False)
