import os
import glob
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import pytorch_lightning as pl
import random
from sklearn.model_selection import train_test_split
import nibabel as nib
from utils.image_processing import get_smaller_structures_mask


def load_datasets(root="~/DP/data/syn_3d.py/extracted_data/"):
    """
    Collect raw image and corresponding segmentation mask paths.

    The root path is run through ``os.path.expanduser`` so the ``~`` in the
    default value (and any user-supplied path) resolves correctly; without
    this, the default path never matches and the function silently returns
    empty lists. Results are sorted by filename so the image/mask ordering
    is deterministic across filesystems (``glob`` order is not guaranteed),
    which keeps the downstream seeded splits reproducible.

    Args:
        root (str): Root directory containing 'raw' and 'seg' subdirectories.

    Returns:
        tuple: Two parallel lists of paths - (images, masks). Each mask path
        is derived from the image filename, so pairing is positional.
    """
    root = os.path.expanduser(root)
    images_path = os.path.join(root, 'raw')
    masks_path = os.path.join(root, 'seg')

    images = []
    masks = []
    # sorted() for reproducible ordering; glob order is filesystem-dependent.
    for file in sorted(glob.glob(os.path.join(images_path, '*.nii.gz'))):
        image_name = os.path.basename(file)
        images.append(os.path.join(images_path, image_name))
        masks.append(os.path.join(masks_path, image_name))

    return images, masks


def make_splits(images, masks, random_seed=42, test_size=0.2, val_size=0.2):
    """
    Split the dataset into train, validation, and test sets.

    The split is performed in two stages: first ``test_size`` of the data is
    held out for testing, then ``val_size`` of the remainder is held out for
    validation. Defaults reproduce the original fixed 80/20 + 80/20 split.

    Args:
        images (list): List of image file paths.
        masks (list): List of mask file paths (parallel to ``images``).
        random_seed (int): Random seed for reproducibility.
        test_size (float): Fraction of the full dataset reserved for testing.
        val_size (float): Fraction of the remaining data reserved for validation.

    Returns:
        tuple: (train_images, train_masks, val_images, val_masks,
        test_images, test_masks).
    """
    # Also seed the global RNG so augmentation randomness elsewhere in the
    # pipeline stays reproducible (train_test_split itself only uses
    # random_state).
    random.seed(random_seed)
    train_images, test_images, train_masks, test_masks = train_test_split(
        images, masks, test_size=test_size, random_state=random_seed)
    train_images, val_images, train_masks, val_masks = train_test_split(
        train_images, train_masks, test_size=val_size, random_state=random_seed)

    return train_images, train_masks, val_images, val_masks, test_images, test_masks


class Syn3D(Dataset):
    """
    PyTorch Dataset for 3D synthetic medical image volumes.

    Each sample is a randomly cropped 128x128x128 sub-volume. The raw image
    is scaled to [0, 1] (division by 255) and corrupted with Gaussian noise;
    the mask is returned both as-is and after morphological post-processing
    via ``get_smaller_structures_mask``.

    Args:
        root_dir (str): Path to dataset root.
        images (list): List of image file paths (.nii.gz).
        masks (list): List of corresponding mask file paths (.nii.gz).
        split (str): One of ['train', 'val', 'test']. Stored but not
            currently used to alter behavior.
        add_noise (float): Std of Gaussian noise to add. If < 0, a random
            std in [0, 0.45] is sampled for every sample.
    """
    def __init__(self,
                 root_dir,
                 images=None,
                 masks=None,
                 split='train',
                 add_noise=0):
        self.root_dir = root_dir
        self.images = images
        self.masks = masks
        self.split = split
        # NOTE(review): self.transform is set but never applied in
        # __getitem__ — kept for interface compatibility.
        self.transform = transforms.ToTensor()
        self.add_noise = add_noise

    def __len__(self):
        return len(self.images)

    def random_crop(self, image, mask, crop_factor=(0, 0, 0)):
        """
        Perform random cropping of 3D image and mask volumes.

        Args:
            image (ndarray): 3D image volume.
            mask (ndarray): 3D segmentation mask, same shape as ``image``.
            crop_factor (tuple): Desired output dimensions per axis.

        Returns:
            tuple: Cropped image and mask volumes.

        Raises:
            ValueError: If any volume dimension is smaller than the
                requested crop size (previously this surfaced as an opaque
                ``random.randint`` error).
        """
        d0, d1, d2 = image.shape
        if d0 < crop_factor[0] or d1 < crop_factor[1] or d2 < crop_factor[2]:
            raise ValueError(
                f"Volume of shape {tuple(image.shape)} is smaller than "
                f"requested crop {tuple(crop_factor)}")
        z = random.randint(0, d0 - crop_factor[0])
        y = random.randint(0, d1 - crop_factor[1])
        x = random.randint(0, d2 - crop_factor[2])

        image = image[z:z + crop_factor[0], y:y + crop_factor[1], x:x + crop_factor[2]]
        mask = mask[z:z + crop_factor[0], y:y + crop_factor[1], x:x + crop_factor[2]]
        return image, mask

    def add_gaussian_noise(self, image, std=None):
        """
        Add Gaussian noise to the 3D image tensor and clamp to [0, 1].

        Args:
            image (Tensor): Input image tensor of shape (1, D, H, W).
            std (float, optional): Standard deviation of the noise.
                Defaults to ``self.add_noise`` when None.

        Returns:
            Tensor: Noisy image tensor, clamped to [0, 1].
        """
        if std is None:
            std = self.add_noise
        noise = torch.randn_like(image) * std
        noisy_image = image + noise
        # Clamp because the image has been normalized into [0, 1].
        noisy_image = torch.clamp(noisy_image, 0, 1)
        return noisy_image

    # Backward-compatible alias for the original (misspelled) method name.
    add_gausian_noise = add_gaussian_noise

    def __getitem__(self, idx):
        """
        Load and process a sample from the dataset.

        Returns:
            tuple: (noisy_image, ground_truth_mask, postprocessed_mask),
            each of shape (1, 128, 128, 128).
        """
        image = nib.load(self.images[idx]).get_fdata()
        mask = nib.load(self.masks[idx]).get_fdata()

        image, mask = self.random_crop(image, mask, crop_factor=(128, 128, 128))

        opened_mask = get_smaller_structures_mask(mask, ker_size=3)

        # Assumes raw intensities are in [0, 255] — TODO confirm for this
        # dataset.
        image = torch.tensor(image / 255, dtype=torch.float32)
        mask = torch.tensor(mask, dtype=torch.float32)

        image = image.unsqueeze(0)
        mask = mask.unsqueeze(0)
        opened_mask = torch.tensor(opened_mask, dtype=torch.float32).unsqueeze(0)

        std = self.add_noise
        if self.add_noise < 0:
            # Negative sentinel: sample a fresh noise level per item.
            std = random.uniform(0, 0.45)

        noisy_image = self.add_gaussian_noise(image, std=std)

        return noisy_image, mask, opened_mask


class Syn3DLightning(pl.LightningDataModule):
    """
    PyTorch Lightning DataModule wrapping the Syn3D dataset.

    Args:
        root_dir (str): Root path to the dataset.
        batch_size (int): Number of samples per batch.
        num_workers (int): Number of subprocesses used for data loading.
        add_noise (float): Standard deviation for Gaussian noise, passed
            through to each Syn3D split.
    """
    def __init__(self, root_dir, batch_size=1, num_workers=4, add_noise=0):
        super().__init__()
        self.root_dir = root_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.add_noise = add_noise

        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

    def _make_dataset(self, images, masks, split):
        """Build one Syn3D split with this module's shared settings."""
        return Syn3D(root_dir=self.root_dir, images=images, masks=masks,
                     split=split, add_noise=self.add_noise)

    def setup(self, stage=None):
        """
        Setup datasets for different stages: 'fit' (train/val) and 'test'.

        Args:
            stage (str, optional): Either 'fit', 'test', or None for both.
        """
        images, masks = load_datasets(root=self.root_dir)
        (train_images, train_masks,
         val_images, val_masks,
         test_images, test_masks) = make_splits(images, masks)

        if stage in ('fit', None):
            self.train_dataset = self._make_dataset(train_images, train_masks, 'train')
            self.val_dataset = self._make_dataset(val_images, val_masks, 'val')
        if stage in ('test', None):
            self.test_dataset = self._make_dataset(test_images, test_masks, 'test')

    def _loader(self, dataset, shuffle):
        """Common DataLoader construction for every split."""
        return DataLoader(dataset,
                          batch_size=self.batch_size,
                          shuffle=shuffle,
                          num_workers=self.num_workers)

    def train_dataloader(self):
        return self._loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        return self._loader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        return self._loader(self.test_dataset, shuffle=False)
