"""
Training script for CSNet3D, a 3D segmentation network with channel and spatial attention.

This script handles argument parsing, data loading, model training, and evaluation using
PyTorch Lightning. Training metrics and model checkpoints are logged to Weights & Biases (wandb).
"""
import os
import torch
import wandb
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from dataloader.syn_3d import Syn3DLightning
import argparse
from model.csnet_3d import CSNet3DLightning


def parse_args(argv=None):
    """
    Parse command-line options for training the 3D CSNet model.

    Args:
        argv: Optional list of argument strings. When None (the default) the
            arguments are read from ``sys.argv[1:]``, exactly as before;
            passing an explicit list lets the parser be driven
            programmatically (e.g. from tests or notebooks).

    Returns:
        argparse.Namespace with the fields ``data_path``, ``epochs``, ``lr``,
        ``batch_size``, ``ckpt_path``, ``accelerator``, ``devices`` and
        ``add_noise``.
    """
    parser = argparse.ArgumentParser(description="Train 3D CSNet Model")
    parser.add_argument('--data_path', type=str, default='/home/xmoravc/DP/data/syn_3d/extracted_data',
                        help='Path to the dataset')
    parser.add_argument('--epochs', type=int, default=200, help='Number of training epochs')
    parser.add_argument('--lr', type=float, default=0.0001, help='Learning rate')
    parser.add_argument('--batch_size', type=int, default=4, help='Batch size')
    parser.add_argument('--ckpt_path', type=str, default='checkpoint', help='Path to save checkpoints')
    parser.add_argument('--accelerator', type=str, default='gpu', help='Device type: gpu or cpu')
    parser.add_argument('--devices', type=int, default=1, help='Number of devices')
    parser.add_argument('--add_noise', type=float, default=0, help='Add noise to training data (0 = no noise)')
    # argv=None makes argparse fall back to sys.argv[1:].
    return parser.parse_args(argv)


def main():
    """
    Train CSNet3D with PyTorch Lightning, then evaluate it on the test split.

    Steps:
      1. Parse command-line arguments and start a wandb run.
      2. Build the Syn3D data module and the CSNet3D Lightning model.
      3. Fit with a checkpoint callback that keeps only the single best model
         (lowest ``'val loss'``) in ``args.ckpt_path``, named after the wandb run.
      4. Run the test loop on that best checkpoint rather than on the
         final-epoch weights.
      5. Close the wandb run.
    """
    args = parse_args()

    # Start the wandb run first so its auto-generated name can be reused as the
    # checkpoint filename; WandbLogger below reuses this already-active run.
    wandb.init(project='AI_project')

    # Create checkpoint directory if it doesn't exist
    os.makedirs(args.ckpt_path, exist_ok=True)

    # Anomaly detection helps locate the op that produced NaN/Inf gradients,
    # but it is a debugging aid that slows down every backward pass.
    torch.autograd.set_detect_anomaly(True)

    # Initialise data module
    data_module = Syn3DLightning(
        args.data_path,
        batch_size=args.batch_size,
        num_workers=4,
        add_noise=args.add_noise,
    )
    data_module.setup()

    # One input channel, one output class.
    model = CSNet3DLightning(classes=1, channels=1, lr=args.lr)

    # Wandb logger; log_model=True uploads checkpoints to wandb as artifacts.
    logger = WandbLogger(project='AI_project', log_model=True)
    wandb_run_name = wandb.run.name
    print(f"Run name: {wandb_run_name}")
    checkpoint_filename = str(wandb_run_name)

    # Keep only the best checkpoint by validation loss. The monitored key must
    # match the name the model logs ('val loss', with a space).
    checkpoint_callback = ModelCheckpoint(
        dirpath=args.ckpt_path,
        filename=checkpoint_filename,
        monitor='val loss',
        mode='min',
        save_top_k=1,
    )

    # Log the learning rate once per epoch.
    lr_monitor = LearningRateMonitor(logging_interval='epoch')

    trainer = pl.Trainer(
        max_epochs=args.epochs,
        accelerator=args.accelerator,
        devices=args.devices,
        logger=logger,
        callbacks=[checkpoint_callback, lr_monitor],
        log_every_n_steps=1,
    )

    trainer.fit(model, data_module)
    # Evaluate the best checkpoint tracked by checkpoint_callback. Without
    # ckpt_path, passing `model` would test the last-epoch weights, so the
    # reported test metrics would not correspond to the saved model.
    trainer.test(model, datamodule=data_module, ckpt_path='best')

    # Finalize wandb run
    wandb.finish()


if __name__ == "__main__":
    # Only start training when executed as a script, not when imported.
    main()
