import os
from skimage import io

import matplotlib.pyplot as plt
import nibabel as nib
import torch
import numpy as np
import wandb
import napari
from tifffile import imread
def create_segmentation_plot(image, gt_mask, pred_mask, thresholded_mask):
    """
    Build a 1x4 matplotlib figure for visually comparing a segmentation result.

    Panels, left to right: the input image (default colormap), the ground
    truth mask, the raw model output and the thresholded output (all three
    masks rendered in grayscale). Axes ticks/frames are hidden.

    :param image: Original image (PIL Image or NumPy array)
    :param gt_mask: Ground truth mask (NumPy array)
    :param pred_mask: Model output (NumPy array)
    :param thresholded_mask: Thresholded prediction mask (NumPy array)
    :return: Matplotlib figure (the caller is responsible for closing it)
    """
    # (data, title, colormap) per panel; None keeps imshow's default colormap.
    panel_specs = (
        (image, "Original Image", None),
        (gt_mask, "Ground Truth Mask", "gray"),
        (pred_mask, "Model Output", "gray"),
        (thresholded_mask, "Thresholded Output", "gray"),
    )

    figure, axis_row = plt.subplots(1, len(panel_specs), figsize=(12, 4))
    for ax, (panel_data, panel_title, colormap) in zip(axis_row, panel_specs):
        ax.imshow(panel_data, cmap=colormap)
        ax.set_title(panel_title)
        ax.axis("off")

    figure.tight_layout()
    return figure

def create_segmentation_plot_test(image, pred_mask, thresholded_mask):
    """
    Creates a 3-panel matplotlib figure comparing:
    1. Original image
    2. Predicted mask
    3. Thresholded mask

    Same layout as the 4-panel comparison plot but without a ground-truth
    panel (e.g. for data that has no reference mask).

    :param image: Original image (PIL Image or NumPy array)
    :param pred_mask: Model output mask (NumPy array)
    :param thresholded_mask: Thresholded prediction mask (NumPy array)
    :return: Matplotlib figure (the caller is responsible for closing it)
    """
    fig, axes = plt.subplots(1, 3, figsize=(10, 4))

    axes[0].imshow(image)  # Original image
    axes[0].set_title("Original Image")
    axes[0].axis("off")

    axes[1].imshow(pred_mask, cmap='gray')  # Model output
    axes[1].set_title("Predicted Mask")
    axes[1].axis("off")

    axes[2].imshow(thresholded_mask, cmap='gray')  # Thresholded output
    axes[2].set_title("Thresholded Mask")
    axes[2].axis("off")

    # Match the 4-panel variant: shrink spacing so titles/panels don't overlap.
    fig.tight_layout()
    return fig


def log_3d_point_cloud(volume, name="3D_PointCloud", threshold=0.5):
    """
    Logs a 3D volume as a point cloud to WandB.

    Every voxel whose value is strictly greater than ``threshold`` becomes one
    point whose coordinates are that voxel's array indices.

    :param volume: (D, H, W) PyTorch tensor or NumPy array.
    :param name: Key under which the point cloud is logged to WandB.
    :param threshold: Voxels with values above this are kept (default 0.5).
    :return: None. If no voxel exceeds the threshold, a warning is printed
        and nothing is logged.
    """
    if isinstance(volume, torch.Tensor):
        # detach() first: Tensor.numpy() raises on tensors that require grad,
        # which is typical for raw model outputs.
        volume = volume.detach().cpu().numpy()

    # One row of indices per voxel above the threshold.
    points = np.argwhere(volume > threshold)

    if points.shape[0] == 0:
        print(f"Warning: No points above threshold in {name}. Skipping logging.")
        return

    # WandB-compatible point cloud dictionary; points must stay a NumPy array,
    # not a Python list.
    point_cloud = {
        "type": "lidar/beta",
        "points": points.astype(np.float32),
    }

    wandb.log({name: wandb.Object3D(point_cloud)})

def _normalize_to_unit_range(data):
    """Linearly rescale ``data`` to [0, 1]; a constant array maps to all zeros."""
    data_min = data.min()
    data_range = data.max() - data_min
    if data_range == 0:
        # Avoid 0/0 -> NaN for constant volumes, e.g. an empty segmentation mask.
        return np.zeros_like(data)
    return (data - data_min) / data_range


def show_nii(root="~/DP/data/syn_3d.py/extracted_data/", filename="1.nii.gz"):
    """
    Opens a NIfTI volume and its segmentation in a Napari viewer.

    The image is read from ``<root>/raw/<filename>`` and the mask from
    ``<root>/seg/<filename>``. Both are min-max rescaled to [0, 1] for display,
    then the Napari GUI event loop is started via ``napari.run()``.

    :param root: Dataset directory containing ``raw`` and ``seg`` subfolders;
        a leading ``~`` is expanded to the user's home directory.
    :param filename: NIfTI file name present in both subfolders.
    """
    # os.path.join does not expand "~", so the default root would never resolve.
    root = os.path.expanduser(root)
    data_img = nib.load(os.path.join(root, "raw", filename)).get_fdata()
    data_mask = nib.load(os.path.join(root, "seg", filename)).get_fdata()

    # Normalize for visualization (scales data to [0, 1]).
    data_img = _normalize_to_unit_range(data_img)
    data_mask = _normalize_to_unit_range(data_mask)

    # Open in Napari
    viewer = napari.Viewer()
    viewer.add_image(data_img, colormap="gray", name="NIfTI Image")
    viewer.add_image(data_mask, colormap="inferno", name="NIfTI Mask")

    # Start Napari GUI
    napari.run()


def show_preds(idx=0, path="C:/muni/DP/seg1/Segment1/samples/validation", gt=False, threshold=None):
    """
    Opens a saved sample (image, prediction, optional ground truth) in Napari.

    Arrays are loaded from ``<path>/image_<idx>.npy`` and ``<path>/pred_<idx>.npy``,
    plus ``<path>/target_<idx>.npy`` when ``gt`` is true. Starts the Napari GUI
    event loop via ``napari.run()``.

    :param idx: Sample index embedded in the file names.
    :param path: Directory containing the saved ``.npy`` arrays.
    :param gt: If true, also load and show the ground-truth target.
    :param threshold: If not None, additionally show ``pred > threshold``
        (as a 0/1 uint8 image) in a "Thresholded Prediction" layer.
    """
    image = np.load(os.path.join(path, f"image_{idx}.npy"))
    pred = np.load(os.path.join(path, f"pred_{idx}.npy"))

    viewer = napari.Viewer()
    viewer.add_image(image, name="Image")
    viewer.add_image(pred, name="Prediction")

    if gt:
        target = np.load(os.path.join(path, f"target_{idx}.npy"))
        viewer.add_image(target, name="Ground Truth")

    if threshold is not None:
        # Cast the boolean mask to uint8 so it is displayed as a plain 0/1 image.
        thresholded = (pred > threshold).astype(np.uint8)
        viewer.add_image(thresholded, name="Thresholded Prediction")

    napari.run()


if __name__ == "__main__":
    # Inspect validation sample 0 of the TNT run, including its ground truth.
    samples_dir = "C:/muni/DP/seg1/Segment1/TNTs_samples/validation"
    show_preds(idx=0, path=samples_dir, gt=True)

    # ------------------------------------------------------------------
    # Earlier ad-hoc inspections, kept disabled for reference.
    # ------------------------------------------------------------------

    # show_nii(root="C:/muni/DP/3d_dataset", filename="1.nii.gz")
    # show_preds(0, path="C:/muni/DP/seg1/Segment1/samples/validation/")

    # Saved TNT predictions (pred.npy / image.npy per sample directory):
    # path = "C:/muni/DP/seg1/Segment1/tnt_preds_35/161223_ptcG4_x_mito-mCh_5"
    # path = "C:/muni/DP/seg1/Segment1/tnt_preds_25/180322_Sqh-mCh Tub-GFP 16h_110.tif.files"
    # path = "C:/muni/DP/seg1/Segment1/tnt_preds_25/mitoRoundtrip2layers"
    # path = "C:/muni/DP/seg1/Segment1/tnt_preds_25/mitoRoundtripBundling"
    # pred = imread(os.path.join(path, "pred.npy"))
    # img = imread(os.path.join(path, "image.npy"))
    # print(f"Prediction shape: {pred.shape}")
    # pred = pred.squeeze(0)
    # img = img.squeeze(0)
    # viewer = napari.Viewer()
    # viewer.add_image(img, name="Image")
    # viewer.add_image(pred, name="Prediction")
    # napari.run()

    # Raw frames and GT segmentations of the 180322 time-lapse:
    # image1_path = "C:/muni/DP/180322_Sqh-mCh_Tub-GFP_16h_110/180322_Sqh-mCh Tub-GFP 16h_110/01/t000.tif"
    # image2_path = "C:/muni/DP/180322_Sqh-mCh_Tub-GFP_16h_110/180322_Sqh-mCh Tub-GFP 16h_110/01/t017.tif"
    #
    # mask1_path = "C:/muni/DP/180322_Sqh-mCh_Tub-GFP_16h_110/180322_Sqh-mCh Tub-GFP 16h_110/01_GT/SEG/mask000.tif"
    # mask2_path = "C:/muni/DP/180322_Sqh-mCh_Tub-GFP_16h_110/180322_Sqh-mCh Tub-GFP 16h_110/01_GT/SEG/mask017.tif"
    #
    # load images and masks and visualize them using Napari
    # image1 = io.imread(image1_path)
    # image2 = io.imread(image2_path)
    # mask1 = io.imread(mask1_path)
    # mask2 = io.imread(mask2_path)
    #
    # viewer = napari.Viewer()
    # viewer.add_image(image1, name="Image 1")
    # viewer.add_labels(mask1, name="Mask 1")
    # viewer.add_image(image2, name="Image 2")
    # viewer.add_labels(mask2, name="Mask 2")
    #
    # napari.run()


# def show_tnt_pred():
#     # load the saved 2D TNT prediction (.npy) and the matching raw image
#     pred = np.load("C:/muni/DP/CS2Net/CS-Net/tnt_2d_prediction.npy")
#     image = io.imread("C:/muni/DP/TNT_data/180322_Sqh-mCh Tub-GFP 16h_110.tif.files/s_C001Z001T001.tif")
#     # print(f"image type: {type(image)}, image shape: {image.shape}, image dtype: {image.dtype}, max: {image.max()}, min: {image.min()}")
#     image = image.astype(np.float32)
#     image = (image - image.min()) / (image.max() - image.min())
#     viewer = napari.Viewer()
#     viewer.add_image(image, name="Image")
#     viewer.add_image(pred, name="Prediction")
#     napari.run()
#
# if __name__ == "__main__":
#     show_tnt_pred()