import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from utils.visualization import create_segmentation_plot, create_segmentation_plot_test
import wandb
from torchvision import transforms as T
import matplotlib.pyplot as plt
from torchmetrics import Specificity, JaccardIndex
from torchmetrics.classification import BinaryAccuracy
from utils.my_metrics import Sensitivity


def downsample():
    """
    Build the encoder's downsampling layer.

    Returns:
        nn.MaxPool2d: 2x2 max pooling with stride 2, which halves H and W
        (rounding down for odd sizes, PyTorch's default ``ceil_mode=False``).
    """
    pool = nn.MaxPool2d(2, stride=2)
    return pool


def deconv(in_channels, out_channels):
    """
    Build a decoder upsampling layer.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the transposed convolution.

    Returns:
        nn.ConvTranspose2d: 2x2 transposed convolution with stride 2, which
        exactly doubles H and W.
    """
    upsample = nn.ConvTranspose2d(
        in_channels,
        out_channels,
        kernel_size=2,
        stride=2,
    )
    return upsample


def initialize_weights(*models):
    """
    Re-initialise, in place, the parameters of every submodule of the given models.

    - ``nn.Conv2d`` / ``nn.Linear``: Kaiming-normal weights (``nn.init.kaiming_normal_``
      with its default arguments) and a zero bias when the layer has one.
    - ``nn.BatchNorm2d``: weight 1 and bias 0. Layers built with ``affine=False`` have
      no weight/bias (both are ``None``) and are skipped instead of raising.

    Any other module type keeps PyTorch's default initialisation. This includes
    ``nn.ConvTranspose2d``, which is not a subclass of ``nn.Conv2d``.

    Args:
        *models: ``nn.Module`` instances; each is walked recursively via ``.modules()``.
    """
    for model in models:
        for module in model.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.BatchNorm2d) and module.affine:
                # Only affine BatchNorm layers own learnable weight/bias tensors.
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)


class ResEncoder(nn.Module):
    """
    Residual encoder block.

    The main path is conv3x3 -> BN -> ReLU -> conv3x3 -> BN -> ReLU. The shortcut is a
    1x1 convolution that projects the input to ``out_channels``. The two are summed and
    passed through a final ReLU. ``padding=1`` on the 3x3 convs keeps H and W unchanged.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Main path (registration order kept stable: it fixes init RNG order and state_dict keys).
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=False)
        # Shortcut projection so the residual matches the main path's channel count.
        self.conv1x1 = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        shortcut = self.conv1x1(x)
        hidden = self.conv1(x)
        hidden = self.relu(self.bn1(hidden))
        hidden = self.conv2(hidden)
        hidden = self.relu(self.bn2(hidden))
        return self.relu(hidden + shortcut)


class Decoder(nn.Module):
    """
    Decoder block: two (conv3x3 -> BN -> ReLU) stages held in ``self.conv``.

    The first conv maps ``in_channels`` to ``out_channels`` and the second keeps
    ``out_channels``. ``padding=1`` preserves H and W.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        # Same layer order as a hand-written Sequential: indices 0..5 (state_dict keys) are unchanged.
        for stage_in in (in_channels, out_channels):
            layers.extend([
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=False),
            ])
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)


class SpatialAttentionBlock(nn.Module):
    """
    Spatial (position) self-attention.

    Query and key are channel-reduced projections (``in_channels // 8`` channels)
    built from 1x3 and 3x1 convolutions followed by BN and ReLU. Value is a 1x1 conv
    that keeps ``in_channels``. Each of the H*W positions attends over all H*W positions.
    The attended features are scaled by the learnable ``gamma`` and added back to the
    input. ``gamma`` starts at 0, so the block is initially the identity.
    """
    def __init__(self, in_channels):
        super().__init__()
        reduced = in_channels // 8
        self.query = nn.Sequential(
            nn.Conv2d(in_channels, reduced, kernel_size=(1, 3), padding=(0, 1)),
            nn.BatchNorm2d(reduced),
            nn.ReLU(inplace=False),
        )
        self.key = nn.Sequential(
            nn.Conv2d(in_channels, reduced, kernel_size=(3, 1), padding=(1, 0)),
            nn.BatchNorm2d(reduced),
            nn.ReLU(inplace=False),
        )
        self.value = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        batch, channels, height, width = x.size()
        queries = self.query(x).flatten(2).transpose(1, 2)  # (B, H*W, C//8)
        keys = self.key(x).flatten(2)                       # (B, C//8, H*W)
        # Row i holds position i's attention distribution over all positions.
        attention = self.softmax(torch.matmul(queries, keys))  # (B, H*W, H*W)
        values = self.value(x).flatten(2)                   # (B, C, H*W)
        attended = torch.matmul(values, attention.transpose(1, 2))
        attended = attended.view(batch, channels, height, width)
        return self.gamma * attended + x


class ChannelAttentionBlock(nn.Module):
    """
    Channel self-attention.

    A C x C affinity (Gram) matrix of the flattened channel responses is built. Each
    row is subtracted from its own maximum and passed through a softmax, so channels
    with *lower* affinity to channel c get more weight. The result is used to mix the
    channels. The mix is scaled by the learnable ``gamma`` (initially 0, i.e. identity)
    and added back to the input.

    ``in_channels`` is accepted for interface symmetry with the spatial block but is
    not used.
    """
    def __init__(self, in_channels):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        batch, channels, height, width = x.size()
        flat = x.view(batch, channels, -1)                       # (B, C, H*W)
        energy = torch.matmul(flat, flat.transpose(1, 2))        # (B, C, C)
        row_max = energy.max(dim=-1, keepdim=True).values
        attention = self.softmax(row_max - energy)
        mixed = torch.matmul(attention, flat).view(batch, channels, height, width)
        return self.gamma * mixed + x


class AffinityAttention(nn.Module):
    """
    Affinity attention: runs a spatial and a channel attention block on the same
    input and returns the element-wise sum of their outputs.
    """
    def __init__(self, in_channels):
        super().__init__()
        self.sab = SpatialAttentionBlock(in_channels)  # spatial attention branch
        self.cab = ChannelAttentionBlock(in_channels)  # channel attention branch

    def forward(self, x):
        """
        Args:
            x: feature map with ``in_channels`` channels.

        Returns:
            ``sab(x) + cab(x)``, same shape as ``x``.
        """
        spatial = self.sab(x)
        channel = self.cab(x)
        return spatial + channel


class CSNet(nn.Module):
    def __init__(self, classes, channels):
        """
        Build CSNet.

        The network has a five-stage residual encoder (32 -> 512 channels), an affinity
        (spatial + channel) attention bottleneck, and a four-stage decoder that upsamples
        with transposed convs and concatenates the matching encoder skip connection.

        Args:
            classes: number of output maps produced by the final 1x1 conv.
            channels: number of channels in the input image.
        """
        super().__init__()
        # NOTE: registration order is significant. It fixes the RNG order used by
        # initialize_weights and the parameter order seen by the optimizer.
        self.enc_input = ResEncoder(channels, 32)
        self.encoder1 = ResEncoder(32, 64)
        self.encoder2 = ResEncoder(64, 128)
        self.encoder3 = ResEncoder(128, 256)
        self.encoder4 = ResEncoder(256, 512)
        self.downsample = downsample()
        self.affinity_attention = AffinityAttention(512)
        # Not used by forward(); kept so existing checkpoints' state_dict keys still load.
        self.attention_fuse = nn.Conv2d(512 * 2, 512, kernel_size=1)
        self.decoder4 = Decoder(512, 256)
        self.decoder3 = Decoder(256, 128)
        self.decoder2 = Decoder(128, 64)
        self.decoder1 = Decoder(64, 32)
        self.deconv4 = deconv(512, 256)
        self.deconv3 = deconv(256, 128)
        self.deconv2 = deconv(128, 64)
        self.deconv1 = deconv(64, 32)
        self.final = nn.Conv2d(32, classes, kernel_size=1)
        initialize_weights(self)

    def forward(self, x):
        """
        Args:
            x: image batch of shape (B, channels, H, W). H and W should be divisible
               by 16: there are four 2x poolings, and each upsampled map must match its
               skip connection for ``torch.cat``.

        Returns:
            Sigmoid probabilities of shape (B, classes, H, W).
        """
        feature = self.enc_input(x)
        skips = []
        for encoder in (self.encoder1, self.encoder2, self.encoder3, self.encoder4):
            skips.append(feature)
            feature = encoder(self.downsample(feature))

        # Residual fusion: bottleneck features plus their attention-refined version.
        feature = feature + self.affinity_attention(feature)

        # Deepest decoder stage pairs with the deepest stored skip (enc3), and so on.
        stages = zip(
            (self.deconv4, self.deconv3, self.deconv2, self.deconv1),
            (self.decoder4, self.decoder3, self.decoder2, self.decoder1),
            reversed(skips),
        )
        for upsample, decode, skip in stages:
            feature = decode(torch.cat((skip, upsample(feature)), dim=1))

        return torch.sigmoid(self.final(feature))


class CSNetLightning(pl.LightningModule):
    """
    PyTorch Lightning wrapper for CSNet training, validation, and testing.

    Args:
        classes (int): Number of output classes.
        channels (int): Number of input image channels.
        lr (float): Learning rate.
        threshold (float): Threshold for binary mask conversion.
    """
    def __init__(self, classes, channels, lr=1e-4, threshold=0.5):
        super().__init__()
        self.model = CSNet(classes, channels)
        self.lr = lr
        # CSNet.forward already ends in a sigmoid, so the loss consumes probabilities.
        self.criterion = nn.BCELoss()
        self.save_hyperparameters()
        # (images, targets, preds) for the first 4 items of validation batch 0, held on CPU.
        self.validation_samples = None
        # One (image, probability map, thresholded mask) tuple per test image.
        self.test_samples = []
        self.threshold = threshold

        # Separate metric instances per stage so train and validation states never mix.
        self.train_accu = BinaryAccuracy(threshold=threshold)
        self.val_accu = BinaryAccuracy(threshold=threshold)

        self.train_sensitivity = Sensitivity(threshold=threshold)
        self.val_sensitivity = Sensitivity(threshold=threshold)

        self.train_specificity = Specificity(num_classes=1, threshold=threshold, average='macro', task='binary')
        self.val_specificity = Specificity(num_classes=1, threshold=threshold, average='macro', task='binary')

        self.train_jaccard = JaccardIndex(num_classes=1, threshold=threshold, average='macro', task='binary')
        self.val_jaccard = JaccardIndex(num_classes=1, threshold=threshold, average='macro',  task='binary')

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        """
        AdamW (weight decay 5e-4) with a polynomial ("poly", power 0.9) LR decay that
        reaches 0 at ``trainer.max_epochs``. The scheduler is stepped by Lightning at
        its default ("epoch") interval.
        """
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.lr, weight_decay=0.0005)

        def poly_decay(epoch):
            max_epochs = self.trainer.max_epochs
            if not max_epochs or max_epochs < 0:
                # No finite epoch budget (None/0, or -1 meaning "unbounded"):
                # 1 - epoch / max_epochs would divide by zero or *grow* the LR.
                return 1.0
            # Clamp at 0: past max_epochs the base goes negative, and a negative
            # float raised to 0.9 is a complex number in Python.
            return max(0.0, 1.0 - epoch / max_epochs) ** 0.9

        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, poly_decay)
        return {"optimizer": optimizer, "lr_scheduler": scheduler}

    def training_step(self, batch, batch_idx):
        """Compute the BCE loss, log it with the batch metric values, and return it."""
        image, target = batch
        preds = self(image)
        loss = self.criterion(preds, target)
        accuracy = self.train_accu(preds, target)
        sensitivity = self.train_sensitivity(preds, target)
        specificity = self.train_specificity(preds, target)
        iou = self.train_jaccard(preds, target)
        self.log('train loss', loss, prog_bar=True)
        self.log('train acc', accuracy, prog_bar=True)
        self.log('train sen', sensitivity, prog_bar=True)
        self.log('train spe', specificity, prog_bar=True)
        self.log('train iou', iou, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation loss/metrics and stash a few samples of batch 0 for plotting."""
        images, targets = batch
        preds = self(images)
        loss = self.criterion(preds, targets)
        accuracy = self.val_accu(preds, targets)
        sensitivity = self.val_sensitivity(preds, targets)
        specificity = self.val_specificity(preds, targets)
        iou = self.val_jaccard(preds, targets)

        self.log('val loss', loss, prog_bar=True)
        self.log('val acc', accuracy, prog_bar=True)
        self.log('val sen', sensitivity, prog_bar=True)
        self.log('val spe', specificity, prog_bar=True)
        self.log('val iou', iou, prog_bar=True)

        if batch_idx == 0:
            # Copy the first 4 items off the device. A bare slice is a view, and it
            # would keep the *whole* batch alive until the next epoch replaced it.
            self.validation_samples = tuple(
                t[:4].detach().cpu().clone() for t in (images, targets, preds)
            )
        return loss

    def on_validation_epoch_end(self):
        """Logs original images, predicted masks, and ground truth masks after each epoch."""
        if self.validation_samples is None:
            return  # Skip if no images were stored

        images, masks, preds = self.validation_samples
        # Consume the samples so stale ones are never logged twice.
        self.validation_samples = None
        thresholded_preds = (preds > self.threshold).float()  # Threshold to binary mask

        to_pil = T.ToPILImage()
        log_images = []
        for i, (image, mask, pred, binary) in enumerate(
            zip(images, masks, preds, thresholded_preds), start=1
        ):
            fig = create_segmentation_plot(
                to_pil(image),
                mask.squeeze(0).numpy(),
                pred.squeeze(0).numpy(),
                binary.squeeze(0).numpy(),
            )
            log_images.append(wandb.Image(fig, caption=f"Sample {i}"))
            plt.close(fig)  # Close the figure to prevent memory leaks

        wandb.log({"Validation Predictions": log_images})

    def test_step(self, batch, batch_idx):
        """Collect per-image predictions. The test batch is expected to be the image tensor itself."""
        images = batch
        preds = self(images)
        thresholded_preds = (preds > self.threshold).float()
        for image, pred, binary in zip(images, preds, thresholded_preds):
            self.test_samples.append(
                (image.cpu(), pred.squeeze(0).cpu().numpy(),
                 binary.squeeze(0).cpu().numpy())
            )

    def on_test_epoch_end(self):
        """Saves and logs test images with predicted masks."""
        if not self.test_samples:
            return  # No test samples found

        # Take the buffer and reset it, so a later trainer.test() run does not
        # re-log (and duplicate) this run's samples.
        samples, self.test_samples = self.test_samples, []

        to_pil = T.ToPILImage()
        log_images = []
        for i, (img_tensor, pred_mask, thresholded_mask) in enumerate(samples, start=1):
            fig = create_segmentation_plot_test(to_pil(img_tensor), pred_mask, thresholded_mask)
            log_images.append(wandb.Image(fig, caption=f"Test Sample {i}"))
            plt.close(fig)

        wandb.log({"Test Predictions": log_images})

