import numpy as np
import torch
from torchmetrics import Metric
from scipy.spatial import cKDTree
import torch.nn as nn
import torch.nn.functional as F


class Sensitivity(Metric):
    """
    Sensitivity (recall) metric, TP / (TP + FN), accumulated over batches.

    Predictions are binarized with ``preds >= threshold``; a target element
    counts as positive when it equals 1. Both counters are float32 states
    that are reduced by summation when synchronized.

    Args:
        threshold (float, optional): Cut-off applied directly to ``preds``;
            values at or above it are positive. Default is 0.5. Because the
            comparison is on the raw values, a 0.5 cut-off presumably expects
            probabilities (for sigmoid logits the equivalent cut-off is 0).
        dist_sync_on_step (bool, optional): Forwarded to ``Metric``.
    """

    def __init__(self, threshold: float = 0.5, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.threshold = threshold

        # Running totals, registered in this order, starting at zero.
        for state_name in ("true_positives", "false_negatives"):
            self.add_state(
                state_name,
                default=torch.tensor(0, dtype=torch.float32),
                dist_reduce_fx="sum",
            )

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        """
        Add one batch's TP and FN counts to the running totals.

        Args:
            preds (torch.Tensor): Prediction scores, compared against ``threshold``.
            target (torch.Tensor): Ground-truth labels; elements equal to 1 are positives.
        """
        predicted_positive = preds >= self.threshold
        actual_positive = target.float() == 1

        hits = predicted_positive & actual_positive
        misses = actual_positive & ~predicted_positive

        self.true_positives += hits.sum()
        self.false_negatives += misses.sum()

    def compute(self):
        """
        Return the sensitivity over everything accumulated so far.

        The 1e-10 term in the denominator guards against division by zero, so
        the result is 0 when no positive targets have been seen.
        """
        tp = self.true_positives
        denominator = tp + self.false_negatives + 1e-10
        return tp / denominator


class HausdorffDistance(nn.Module):
    """
    Computes the Hausdorff Distance (HD) and a percentile Hausdorff Distance
    (HD95 with the default ``percentile``) between two segmentation masks.

    Distances are Euclidean, measured in voxel-index units between the
    coordinates of *all* foreground voxels (``mask > threshold``) of each
    mask. The two directed nearest-neighbour distance sets (mask 1 -> mask 2
    and mask 2 -> mask 1) are pooled. HD is their maximum and HD95 is their
    ``percentile``-th percentile.

    NOTE(review): surface-based HD95 implementations use boundary voxels only.
    Here, interior voxels that overlap the other mask contribute zero
    distances, which can lower the percentile value relative to such
    implementations. Confirm this is the intended definition.
    """

    # Cap on the number of entries in one pairwise-distance block computed in
    # ``_min_distance``. 2**24 float32 values = 64 MiB, so large masks never
    # materialize the full N x M distance matrix.
    _MAX_BLOCK_ELEMENTS = 2 ** 24

    def __init__(self, percentile=95, threshold=0.5):
        """
        Initializes the HausdorffDistance module.

        :param percentile: Percentile (0-100) used for the second returned
            value (default is 95).
        :param threshold: Values strictly greater than this are foreground.
        """
        super().__init__()
        self.percentile = percentile
        self.threshold = threshold
        # Device that inputs are moved to and results are created on.
        self.device = torch.device("cpu")

    def to(self, *args, **kwargs):
        """
        Moves the module and records the target device used by ``forward``.

        Accepts the same argument forms as ``nn.Module.to``: a device, dtype,
        or tensor positionally, and/or keywords such as ``device=`` and
        ``non_blocking=``. Only a device argument, or the device of a tensor
        argument, updates ``self.device``. A dtype-only call leaves it unchanged.

        :return: ``self``, to allow chaining ``.to(device)``.
        """
        device = kwargs.get("device")
        for arg in args:
            if isinstance(arg, torch.Tensor):
                device = arg.device
            elif isinstance(arg, (str, torch.device)) or (
                isinstance(arg, int) and not isinstance(arg, bool)
            ):
                device = arg
        # Let nn.Module validate the arguments before the recorded device changes.
        module = super().to(*args, **kwargs)
        if device is not None:
            self.device = device
        return module

    def forward(self, seg1, seg2):
        """
        Computes the Hausdorff Distance (HD) and the percentile Hausdorff
        Distance (HD95 for the default ``percentile``).

        :param seg1: Segmentation mask 1 (binary or scores), expected to be 3D.
            Every dimension of the tensor is treated as a spatial axis, so a
            batched input would mix samples. Pass one volume at a time.
        :param seg2: Segmentation mask 2, same conventions as ``seg1``.
        :return: ``(hd, hd95)`` as 0-d float tensors on ``self.device``. Both
            are NaN when either mask has no foreground voxel.
        """
        seg1 = seg1.to(self.device)
        seg2 = seg2.to(self.device)

        # One row per foreground voxel, one column per tensor dimension.
        seg1_points = torch.nonzero(seg1 > self.threshold, as_tuple=False)
        seg2_points = torch.nonzero(seg2 > self.threshold, as_tuple=False)

        if seg1_points.numel() == 0 or seg2_points.numel() == 0:
            # Distance to an empty set is undefined: return NaN instead of raising.
            return (
                torch.tensor(float("nan"), device=self.device),
                torch.tensor(float("nan"), device=self.device),
            )

        dists_1_to_2 = self._min_distance(seg1_points, seg2_points)
        dists_2_to_1 = self._min_distance(seg2_points, seg1_points)

        # Pool both directed distance sets.
        all_dists = torch.cat([dists_1_to_2, dists_2_to_1])

        hd = torch.max(all_dists)
        # torch.quantile's default interpolation is linear between order statistics.
        hd95 = torch.quantile(all_dists, self.percentile / 100.0)

        return hd, hd95

    def _min_distance(self, points1, points2):
        """
        Computes, for each point in points1, the Euclidean distance to the
        closest point in points2.

        Pairwise distances are evaluated in row blocks of about
        ``_MAX_BLOCK_ELEMENTS`` entries (at least one full row of M). Peak
        memory is therefore bounded by ``max(M, _MAX_BLOCK_ELEMENTS)`` values
        instead of ``N * M``.

        :param points1: Tensor of shape (N, D) - voxel coordinates
        :param points2: Tensor of shape (M, D) - voxel coordinates
        :return: Tensor of distances of shape (N,); ``[inf]`` if either input is empty
        """
        if points1.shape[0] == 0 or points2.shape[0] == 0:
            return torch.tensor([float("inf")], device=self.device)

        points1 = points1.float()
        points2 = points2.float()

        rows_per_block = max(1, self._MAX_BLOCK_ELEMENTS // points2.shape[0])
        block_minima = [
            torch.cdist(block, points2, p=2).min(dim=1).values
            for block in torch.split(points1, rows_per_block)
        ]
        return torch.cat(block_minima)
