from __future__ import division
import os.path
from medpy.metric import hd
import numpy as np
import torch
import torch.nn as nn
import pytorch_lightning as pl
from torchmetrics import Specificity, JaccardIndex
from torchmetrics.classification import BinaryAccuracy
from utils.my_metrics import Sensitivity, HausdorffDistance


def downsample():
    """Return a max-pooling layer that halves each spatial axis of a (B, C, H, W, D) tensor."""
    factor = 2
    return nn.MaxPool3d(kernel_size=factor, stride=factor)


def deconv(in_channels, out_channels):
    """Return a transposed 3D convolution that doubles each spatial axis.

    :param in_channels: channels of the incoming feature map.
    :param out_channels: channels produced by the upsampling layer.
    """
    scale = 2
    return nn.ConvTranspose3d(in_channels, out_channels, kernel_size=scale, stride=scale)


def initialize_weights(*models):
    """Re-initialise Conv3d / Linear / BatchNorm3d layers of the given modules in place.

    * ``nn.Conv3d`` and ``nn.Linear``: Kaiming-normal weights (PyTorch defaults for
      ``a``, ``mode`` and ``nonlinearity``), biases set to zero.
    * ``nn.BatchNorm3d``: affine weight set to one and bias set to zero; layers built
      with ``affine=False`` have no such parameters and are left untouched.

    Other layer types (e.g. ``nn.ConvTranspose3d``) keep their default initialisation.

    :param models: one or more ``nn.Module`` instances; all their submodules are visited
        in ``modules()`` order, which determines the order of random draws.
    """
    for model in models:
        for m in model.modules():
            if isinstance(m, (nn.Conv3d, nn.Linear)):
                # kaiming_normal_ is the non-deprecated, numerically identical form of
                # nn.init.kaiming_normal.
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm3d):
                # weight/bias are None when the layer was created with affine=False.
                if m.weight is not None:
                    nn.init.ones_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)


class ResEncoder3d(nn.Module):
    """Residual encoder block: two 3x3x3 Conv-BN-ReLU layers plus a 1x1x1 projection shortcut.

    The shortcut projects the input to ``out_channels`` so it can be added to the main
    path; a final ReLU is applied after the addition. Spatial size is preserved
    (kernel 3 with padding 1, kernel 1 without padding).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Registration order (conv1, bn1, conv2, bn2, relu, conv1x1) fixes the
        # state_dict layout and the parameter-initialisation order; keep it.
        self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm3d(out_channels)
        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm3d(out_channels)
        self.relu = nn.ReLU(inplace=False)
        self.conv1x1 = nn.Conv3d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        shortcut = self.conv1x1(x)
        main = self.relu(self.bn1(self.conv1(x)))
        main = self.relu(self.bn2(self.conv2(main)))
        return self.relu(main + shortcut)


class Decoder3d(nn.Module):
    """Decoder block: two stacked (Conv3d 3x3x3 -> BatchNorm3d -> ReLU) stages.

    The first convolution maps ``in_channels`` to ``out_channels``, the second keeps
    ``out_channels``; padding 1 preserves the spatial size.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        for stage_in in (in_channels, out_channels):
            layers.extend([
                nn.Conv3d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm3d(out_channels),
                nn.ReLU(inplace=False),
            ])
        # A single Sequential named `conv` keeps the state_dict keys conv.0 ... conv.5.
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)


class SpatialAttentionBlock3d(nn.Module):
    """Spatial self-attention over all voxels of a 3D feature map.

    Query, key and "judge" projections are 1-D convolutions (kernel 3) along a
    different spatial axis each: query along dim 3, key along dim 2, judge along
    dim 4 of the (B, C, H, W, D) input. All three reduce the channels to
    ``in_channels // 8``. With N = H*W*D the voxel affinity is

        A = softmax((Q @ K) @ (J @ K))        # (B, N, N), softmax over the last axis

    and the output is ``gamma * (V @ A) + x``, where V is a 1x1x1 projection and
    ``gamma`` starts at zero (so the block initially returns its input). The N x N
    affinity makes memory grow quadratically with the number of voxels.
    """

    def __init__(self, in_channels):
        super().__init__()
        reduced = in_channels // 8
        # Construction order is kept so parameter initialisation draws are unchanged.
        self.query = nn.Conv3d(in_channels, reduced, kernel_size=(1, 3, 1), padding=(0, 1, 0))
        self.key = nn.Conv3d(in_channels, reduced, kernel_size=(3, 1, 1), padding=(1, 0, 0))
        self.judge = nn.Conv3d(in_channels, reduced, kernel_size=(1, 1, 3), padding=(0, 0, 1))
        self.value = nn.Conv3d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """
        :param x: input tensor of shape (B, C, H, W, D)
        :return: ``gamma * attended + x``, same shape as ``x``
        """
        batch, _, height, width, depth = x.size()
        n_voxels = height * width * depth

        q = self.query(x).view(batch, -1, n_voxels).transpose(1, 2)  # [B, N, C/8]
        k = self.key(x).view(batch, -1, n_voxels)                    # [B, C/8, N]
        j = self.judge(x).view(batch, -1, n_voxels).transpose(1, 2)  # [B, N, C/8]

        affinity = self.softmax((q @ k) @ (j @ k))                   # [B, N, N]

        v = self.value(x).view(batch, -1, n_voxels)                  # [B, C, N]
        attended = (v @ affinity).view_as(x)                         # [B, C, H, W, D]
        return self.gamma * attended + x


class ChannelAttentionBlock3d(nn.Module):
    """Channel-wise self-attention for 3D volumes.

    With F the input flattened to (B, C, N), N = H*W*D, the channel affinity is
    ``A = G @ G`` where ``G = F @ F^T`` is the (B, C, C) Gram matrix. Each row is
    turned into ``max(row) - row`` before a softmax, and the output is
    ``gamma * (softmax(...) @ F) + x``. ``gamma`` starts at zero, so the block
    initially returns its input.
    """

    def __init__(self, in_channels):
        """
        :param in_channels: number of input channels. It is not needed to build any
            layer (the block has no convolutions) and is kept for interface parity
            with SpatialAttentionBlock3d.
        """
        super(ChannelAttentionBlock3d, self).__init__()
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """
        :param x: input tensor of shape (B, C, H, W, D); need not be contiguous.
        :return: ``gamma * weights + x``, same shape as ``x``
        """
        B, C, H, W, D = x.size()
        # reshape (unlike view) also accepts non-contiguous inputs.
        feats = x.reshape(B, C, -1)                              # [B, C, N]
        # The query and judge projections are the same transposed tensor, so the
        # Gram matrix is computed once and squared instead of computed twice.
        gram = torch.matmul(feats, feats.transpose(1, 2))        # [B, C, C]
        affinity = torch.matmul(gram, gram)                      # [B, C, C]
        # Row-wise max minus affinity, as in the original formulation.
        affinity_new = affinity.amax(dim=-1, keepdim=True) - affinity
        affinity_new = self.softmax(affinity_new)
        weights = torch.matmul(affinity_new, feats).reshape(B, C, H, W, D)
        return self.gamma * weights + x


class AffinityAttention3d(nn.Module):
    """Sum of spatial and channel affinity attention over a 3D feature map.

    Returns ``sab(x) + cab(x) + x``. Each sub-block already returns
    ``gamma * weights + x``, so the input appears three times in the sum.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.sab = SpatialAttentionBlock3d(in_channels)  # spatial attention block
        self.cab = ChannelAttentionBlock3d(in_channels)  # channel attention block

    def forward(self, x):
        """
        :param x: input tensor of shape (B, C, H, W, D)
        :return: sab(x) + cab(x) + x
        """
        return self.sab(x) + self.cab(x) + x


class CSNet3D(nn.Module):
    """3D encoder-decoder segmentation network with an affinity-attention bottleneck.

    Encoder: a residual input block (16 channels) followed by four
    max-pool + residual-encoder stages (32, 64, 128, 256 channels). Each pool halves
    every spatial axis, so the bottleneck is 16x smaller per axis than the input.
    Bottleneck: features are added to the output of AffinityAttention3d.
    Decoder: four transposed-conv upsampling stages, each concatenated with the
    matching encoder output (skip first, upsampled second) and refined by a Decoder3d.
    Head: a 1x1x1 convolution followed by a sigmoid, giving per-voxel probabilities.

    Every spatial dimension of the input must be divisible by 16; otherwise the
    upsampled features do not match the skip connections in ``torch.cat``.
    """

    def __init__(self, classes, channels):
        """
        :param classes: number of output channels (object classes).
        :param channels: number of channels of the input volume.
        """
        super().__init__()
        # Attribute assignment order fixes module registration order, which drives the
        # state_dict layout and the random draws made by initialize_weights; keep it.
        self.enc_input = ResEncoder3d(channels, 16)
        self.encoder1 = ResEncoder3d(16, 32)
        self.encoder2 = ResEncoder3d(32, 64)
        self.encoder3 = ResEncoder3d(64, 128)
        self.encoder4 = ResEncoder3d(128, 256)
        self.downsample = downsample()
        self.affinity_attention = AffinityAttention3d(256)
        # NOTE(review): attention_fuse is never called in forward(). It is still a
        # registered submodule, so removing it would change the state_dict keys
        # (presumably why it is kept) - confirm before deleting.
        self.attention_fuse = nn.Conv3d(256 * 2, 256, kernel_size=1)
        self.decoder4 = Decoder3d(256, 128)
        self.decoder3 = Decoder3d(128, 64)
        self.decoder2 = Decoder3d(64, 32)
        self.decoder1 = Decoder3d(32, 16)
        self.deconv4 = deconv(256, 128)
        self.deconv3 = deconv(128, 64)
        self.deconv2 = deconv(64, 32)
        self.deconv1 = deconv(32, 16)
        self.final = nn.Conv3d(16, classes, kernel_size=1)
        initialize_weights(self)

    def forward(self, x):
        # Encoder: remember each stage's output (before pooling) as a skip connection.
        skips = []
        features = self.enc_input(x)
        for encoder in (self.encoder1, self.encoder2, self.encoder3, self.encoder4):
            skips.append(features)
            features = encoder(self.downsample(features))

        # Attention at the bottleneck.
        features = features + self.affinity_attention(features)

        # Decoder: deepest skip first; concatenate (skip, upsampled) along channels.
        stages = (
            (self.deconv4, self.decoder4),
            (self.deconv3, self.decoder3),
            (self.deconv2, self.decoder2),
            (self.deconv1, self.decoder1),
        )
        for (upsample, decode), skip in zip(stages, reversed(skips)):
            features = decode(torch.cat((skip, upsample(features)), dim=1))

        return torch.sigmoid(self.final(features))


class CSNet3DLightning(pl.LightningModule):
    """
    PyTorch LightningModule for CSNet3D training, validation, and testing.

    Wraps :class:`CSNet3D`, trains it with ``nn.BCELoss`` on the network's sigmoid
    outputs, and tracks accuracy / sensitivity / specificity / Jaccard per stage, plus
    an optional Hausdorff-distance metric. It saves a few samples as ``.npy`` files
    under ``./samples/{validation,test}``. At test time it also reports a MedPy
    Hausdorff distance and prints mean/variance summaries.

    Batches are unpacked as ``(images, targets, opened)``.

    Args:
        classes (int): Number of output classes (usually 1).
        channels (int): Number of input channels.
        lr (float): Learning rate for AdamW.
        threshold (float): Threshold for converting probabilities to binary predictions.
        noise (float): Not used by this class; only recorded by ``save_hyperparameters``.
        log_to_logger (bool): Enable/disable logging to the Lightning logger.
        model_noise (float): Not used by this class; only recorded by ``save_hyperparameters``.
        eval_hd (bool): Compute the ``HausdorffDistance`` metric (HD and HD95) in every
            step. The MedPy Hausdorff distance in ``test_step`` is computed regardless.
    """

    def __init__(self, classes, channels, lr=1e-4, threshold=0.5, noise=0, log_to_logger=True,
                 model_noise=0, eval_hd=False):
        super(CSNet3DLightning, self).__init__()
        self.model = CSNet3D(classes, channels)
        self.lr = lr
        self.save_hyperparameters()
        self.validation_samples = []
        self.test_samples = []
        self.threshold = threshold
        self.criterion = nn.BCELoss()
        self.samples_path = "./samples"
        self.test_outputs = []
        # Running index of saved test samples, so files from different batches don't collide.
        self._test_samples_saved = 0
        self.log_to_logger = log_to_logger
        self.eval_hd = eval_hd
        print(f"Self eval_hd: {self.eval_hd}")

        # NOTE: a plain dict, so these metrics are NOT registered submodules and Lightning
        # does not move them between devices; _move_metrics_to_device does that.
        self.metrics = {
            "train": {
                "accuracy": BinaryAccuracy(threshold=threshold),
                "sensitivity": Sensitivity(threshold=threshold),
                "specificity": Specificity(num_classes=1, threshold=threshold, average='micro', task='binary'),
                "jaccard": JaccardIndex(num_classes=1, threshold=threshold, average='macro', task='binary'),
                "hausdorff_distance": HausdorffDistance(percentile=95, threshold=threshold)
            },
            "val": {
                "accuracy": BinaryAccuracy(threshold=threshold),
                "sensitivity": Sensitivity(threshold=threshold),
                "specificity": Specificity(num_classes=1, threshold=threshold, average='none', task='binary'),
                "jaccard": JaccardIndex(num_classes=1, threshold=threshold, average='none', task='binary'),
                "hausdorff_distance": HausdorffDistance(percentile=95, threshold=threshold)
            },
            "test": {
                "accuracy": BinaryAccuracy(threshold=threshold),
                "sensitivity": Sensitivity(threshold=threshold),
                "specificity": Specificity(num_classes=1, threshold=threshold, average='none', task='binary'),
                "jaccard": JaccardIndex(num_classes=1, threshold=threshold, average='none', task='binary'),
                "hausdorff_distance": HausdorffDistance(percentile=95, threshold=threshold)
            }
        }
        os.makedirs(os.path.join(self.samples_path, "validation"), exist_ok=True)
        os.makedirs(os.path.join(self.samples_path, "test"), exist_ok=True)

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        """AdamW with a per-epoch polynomial decay ``(1 - epoch / max_epochs) ** 0.9``."""
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.lr, weight_decay=0.0005)

        def poly_decay(epoch):
            max_epochs = self.trainer.max_epochs
            # No finite epoch budget (None, 0 or -1 for "infinite"): keep the base LR.
            if not max_epochs or max_epochs < 0:
                return 1.0
            # Clamp at 0: past max_epochs a negative base raised to 0.9 would be complex.
            return max(0.0, 1 - epoch / max_epochs) ** 0.9

        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, poly_decay)
        return {"optimizer": optimizer, "lr_scheduler": scheduler}

    # ------------------------------------------------------------------ device handling
    def _move_metrics_to_device(self):
        """Move every metric in ``self.metrics`` to this module's device."""
        for stage_metrics in self.metrics.values():
            for key, metric in list(stage_metrics.items()):
                stage_metrics[key] = metric.to(self.device)

    def on_fit_start(self):
        self._move_metrics_to_device()

    def on_validation_start(self):
        # Also needed for a standalone trainer.validate(), which skips on_fit_start.
        self._move_metrics_to_device()

    def on_test_start(self):
        self._move_metrics_to_device()

    # ------------------------------------------------------------------ shared step logic
    def common_step(self, batch, metrics):
        """Run the model on one batch and evaluate ``metrics`` (one stage of self.metrics).

        Returns ``(images, targets, preds, tiny_structures_mask, loss, acc, sens, spec,
        jacc, my_hd, hd_95)``. When ``eval_hd`` is False, ``my_hd`` and ``hd_95`` are
        ``-1`` placeholder tensors.
        """
        images, targets, opened = batch
        # Complement of `opened`; returned to callers but not used by the steps below.
        tiny_structures_mask = torch.ones(opened.shape, device=self.device) - opened

        preds = self(images)
        loss = self.criterion(preds, targets)

        acc = metrics["accuracy"](preds, targets)
        sens = metrics["sensitivity"](preds, targets)
        spec = metrics["specificity"](preds, targets)
        jacc = metrics["jaccard"](preds, targets)
        if self.eval_hd:
            my_hd, hd_95 = metrics["hausdorff_distance"](preds, targets)
        else:
            my_hd, hd_95 = torch.tensor(-1), torch.tensor(-1)

        return images, targets, preds, tiny_structures_mask, loss, acc, sens, spec, jacc, my_hd, hd_95

    def log_metrics(self, loss, acc, sens, spec, jacc, my_hd, hd_95, prefix, on_epoch=True, tiny=False):
        """Log the step metrics as ``"<prefix>[ tiny] <name>"``.

        ``on_epoch=True`` logs epoch aggregates only; ``False`` logs per step only.
        The Hausdorff metrics are logged only when ``eval_hd`` is enabled.
        """
        log_name_pref = f"{prefix} tiny" if tiny else prefix
        named_values = [
            ("loss", loss),
            ("acc", acc),
            ("sensitivity", sens),
            ("specificity", spec),
            ("jaccard", jacc),
        ]
        if self.eval_hd:
            named_values += [("hausdorff distance", my_hd), ("hausdorff distance 95", hd_95)]
        for name, value in named_values:
            self.log(f"{log_name_pref} {name}", value, on_epoch=on_epoch, on_step=not on_epoch)

    def _save_samples(self, split, images, preds, targets=None, start_index=0):
        """Save each sample of the batch as ``pred_<k>.npy`` / ``image_<k>.npy`` (and
        ``target_<k>.npy`` when ``targets`` is given) under ``samples_path/split``,
        with ``k = start_index + position in batch``. Singleton dims are squeezed.
        """
        out_dir = os.path.join(self.samples_path, split)
        for i in range(images.shape[0]):
            index = start_index + i
            np.save(os.path.join(out_dir, f"pred_{index}.npy"), preds[i].squeeze().detach().cpu().numpy())
            np.save(os.path.join(out_dir, f"image_{index}.npy"), images[i].squeeze().detach().cpu().numpy())
            if targets is not None:
                np.save(os.path.join(out_dir, f"target_{index}.npy"),
                        targets[i].squeeze().detach().cpu().numpy())

    @staticmethod
    def _batch_medpy_hd(bin_preds, gts):
        """Mean MedPy Hausdorff distance over the samples of a batch.

        ``medpy.metric.hd`` raises ``RuntimeError`` when either volume contains no
        foreground voxel; such samples are treated as undefined (NaN) and skipped.
        Returns NaN when no sample in the batch has a defined distance.
        """
        distances = []
        for pred, gt in zip(bin_preds, gts):
            try:
                distances.append(hd(pred, gt))
            except RuntimeError:
                distances.append(np.nan)
        distances = np.asarray(distances, dtype=float)
        defined = distances[~np.isnan(distances)]
        return float(defined.mean()) if defined.size else float("nan")

    @staticmethod
    def _as_float(value):
        """Return a scalar tensor (any device), NumPy scalar or number as a Python float."""
        if torch.is_tensor(value):
            return float(value.detach().cpu())
        return float(value)

    # ------------------------------------------------------------------ Lightning steps
    def training_step(self, batch, batch_idx):
        images, targets, preds, tiny_struc_masks, loss, acc, sens, spec, jacc, my_hd, hd_95 = self.common_step(
            batch, self.metrics["train"])
        if self.log_to_logger:
            self.log_metrics(loss, acc, sens, spec, jacc, my_hd, hd_95, "train", on_epoch=False)

        return loss

    def validation_step(self, batch, batch_idx):
        images, targets, preds, tiny_struc_masks, loss, acc, sens, spec, jacc, my_hd, hd_95 = self.common_step(
            batch, self.metrics["val"])
        if self.log_to_logger:
            self.log_metrics(loss, acc, sens, spec, jacc, my_hd, hd_95, "val", on_epoch=True)

        # Save the first batch of every validation run (files are overwritten each run).
        if batch_idx == 0:
            self._save_samples("validation", images, preds, targets)
        return loss

    def on_test_epoch_start(self):
        # Start each test run fresh; outputs stay available after the run finishes.
        self.test_outputs.clear()
        self._test_samples_saved = 0

    def test_step(self, batch, batch_idx):
        images, targets, preds, tiny_struc_masks, loss, acc, sens, spec, jacc, my_hd, hd_95 = self.common_step(
            batch, self.metrics["test"])
        if self.log_to_logger:
            self.log_metrics(loss, acc, sens, spec, jacc, my_hd, hd_95, "test", on_epoch=True)

        bin_preds = (preds > self.threshold).cpu().numpy().astype(np.uint8)
        gts = targets.cpu().numpy().astype(np.uint8)
        batch_hd = self._batch_medpy_hd(bin_preds, gts)
        # Skip undefined batches so they don't turn the epoch aggregate into NaN.
        if self.log_to_logger and not np.isnan(batch_hd):
            self.log("test hausdorff distance", batch_hd, on_epoch=True, on_step=False)

        # Save the first six batches; a running index keeps their files distinct.
        if batch_idx <= 5:
            self._save_samples("test", images, preds, start_index=self._test_samples_saved)
            self._test_samples_saved += images.shape[0]

        output = {
            "loss": loss,
            "acc": acc,
            "sens": sens,
            "spec": spec,
            "jacc": jacc,
            "hausdorff_distance": batch_hd,
            "my_hausdorff_distance": my_hd,
            "my_hausdorff_distance_95": hd_95
        }
        self.test_outputs.append(output)
        return loss

    def on_test_epoch_end(self):
        """Print mean and variance of each per-batch test statistic.

        Tensor metrics use ``torch.var`` (unbiased); the Hausdorff values use ``np.var``
        (population). Batches with an undefined MedPy distance are ignored for it.
        """
        outputs = self.test_outputs
        if not outputs:
            print("Test set results: no test batches were processed.")
            print()
            return

        print("Test set results")
        for label, key in (("Loss", "loss"), ("Accuracy", "acc"), ("Sensitivity", "sens"),
                           ("Specificity", "spec"), ("Jaccard", "jacc")):
            values = torch.stack([out[key] for out in outputs])
            print(f"    {label}: mean = {values.mean().item():.3f}, var = {values.var().item()}")

        medpy_hd = np.asarray([out["hausdorff_distance"] for out in outputs], dtype=float)
        defined_hd = medpy_hd[~np.isnan(medpy_hd)]
        mean_hd = defined_hd.mean() if defined_hd.size else float("nan")
        var_hd = defined_hd.var() if defined_hd.size else float("nan")
        print(f"    Hausdorff distance: mean = {mean_hd:.3f}, var = {var_hd}")

        for label, key in (("My Hausdorff distance", "my_hausdorff_distance"),
                           ("My Hausdorff distance 95", "my_hausdorff_distance_95")):
            # Convert on the host: np.mean cannot consume CUDA tensors directly.
            values = np.asarray([self._as_float(out[key]) for out in outputs])
            print(f"    {label}: mean = {values.mean():.3f}, var = {values.var()}")
        print()
