import os

import numpy as np
import torch
import pytorch_lightning as pl
import torch.nn.functional as F
from torchvision.transforms import Resize
from tifffile import imsave, imwrite

from model.csnet_3d import CSNet3DLightning
from dataloader.syn_3d import Syn3DLightning

# Noise level each trained checkpoint was trained with, keyed by checkpoint
# path. The value is passed to CSNet3DLightning as ``model_noise``.
# NOTE(review): the meaning of -1 is not visible in this file (possibly a
# varying/random noise regime) — TODO confirm.
model_noises = {
    f"checkpoint_3d/{run_name}.ckpt": train_noise
    for run_name, train_noise in (
        ("volcanic-moon-18", 0),
        ("olive-field-19", 25),
        ("wise-paper-20", 35),
        ("ethereal-aardvark-21", 45),
        ("dry-grass-24", -1),
        ("deft-shadow-25", -1),
        ("neat-oath-26", -1),
    )
}



# Run configuration for the evaluation below.
# ``checkpoint_path`` must be one of the keys of ``model_noises``; other
# available checkpoints: volcanic-moon-18, olive-field-19, wise-paper-20,
# ethereal-aardvark-21, dry-grass-24, deft-shadow-25 (all under checkpoint_3d/).
# Alternative ``data_path`` values (TNT recordings):
#   /home/xmoravc/DP/data/TNT_data_3d/161223_ptcG4_x_mito-mCh_5
#   /home/xmoravc/DP/data/TNT_data_3d/180322_Sqh-mCh Tub-GFP 16h_110.tif.files
#   /home/xmoravc/DP/data/TNT_data_3d/mitoRoundtrip2layers
#   /home/xmoravc/DP/data/TNT_data_3d/mitoRoundtripBundling
args = dict(
    checkpoint_path="checkpoint_3d/neat-oath-26.ckpt",
    data_path="/home/xmoravc/DP/data/syn_3d/extracted_data/",
    # Forwarded to pl.Trainer.
    accelerator="gpu",
    devices=1,
    # Noise level given to both the model (``noise``) and the datamodule
    # (``add_noise``).
    # NOTE(review): model_noises uses values like 25/35/45 while this is 0.45;
    # confirm the two use different scales on purpose.
    add_noise=0.45,
)



if __name__ == "__main__":
    # Evaluate one trained CSNet3D checkpoint on the test split of the 3D
    # dataset, using the settings in ``args``.
    checkpoint_path = args["checkpoint_path"]
    try:
        # Noise level the selected checkpoint was trained with.
        mod_noise = model_noises[checkpoint_path]
    except KeyError:
        raise KeyError(
            f"Unknown checkpoint {checkpoint_path!r}; add it to model_noises "
            f"(known checkpoints: {sorted(model_noises)})"
        ) from None

    print("________________________________________________________________________")
    print(f"Model noise: {mod_noise}, checkpoint: {args['checkpoint_path']}, data noise {args['add_noise']}")
    print("________________________________________________________________________")

    # Build the module with the evaluation flags first, then load only the
    # weights, so those flags are not overridden by the checkpoint's saved
    # hyperparameters (presumed reason for not using load_from_checkpoint).
    model = CSNet3DLightning(classes=1, channels=1, lr=0.0001, noise=args["add_noise"],
                             log_to_logger=False, tiny_structures_eval=True,
                             model_noise=mod_noise, eval_hd=True)
    # Load onto the CPU so the checkpoint opens regardless of the device it
    # was saved from; the Trainer moves the model to the configured
    # accelerator before testing.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"])

    # Alternative loading path kept for reference:
    # model = CSNet3DLightning.load_from_checkpoint(args["checkpoint_path"])
    # model.log_to_logger = False
    # model.tiny_structures_eval = True

    # Trainer.test also switches the model to eval mode; kept for clarity.
    model.eval()

    data = Syn3DLightning(args["data_path"], batch_size=1, num_workers=4, add_noise=args["add_noise"])
    # Trainer.test calls setup("test") on the datamodule as well; this early
    # call presumably surfaces data-path problems before the Trainer is built.
    data.setup("test")

    trainer = pl.Trainer(accelerator=args["accelerator"], devices=args["devices"])

    # List of per-dataloader metric dicts returned by Lightning.
    test_results = trainer.test(model, datamodule=data)

    # Evaluation on the real TNT recordings, kept for reference.
    # NOTE(review): ``try_on_tnt`` is not defined in this file; it must be
    # restored or imported before this loop can be re-enabled.
    # paths = ["/home/xmoravc/DP/data/TNT_data_3d/161223_ptcG4_x_mito-mCh_5",
    #         "/home/xmoravc/DP/data/TNT_data_3d/180322_Sqh-mCh Tub-GFP 16h_110.tif.files",
    #         "/home/xmoravc/DP/data/TNT_data_3d/mitoRoundtrip2layers",
    #         "/home/xmoravc/DP/data/TNT_data_3d/mitoRoundtripBundling",]
    # for path in paths:
    #     args["data_path"] = path
    #     try_on_tnt(path)
