"""
Training script for the 2D CSNet model for vessel segmentation on the DRIVE dataset.

This script configures and trains the CSNet model with PyTorch Lightning, then evaluates it on the test split.
Training metrics and model checkpoints are logged to Weights & Biases (wandb).
"""
import os
import torch
import wandb
import argparse
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from dataloader.drive import DriveDataset
from model.csnet import CSNetLightning


def parse_args(argv=None):
    """
    Parse command-line arguments for training configuration.

    Args:
        argv (list[str] | None): Argument strings to parse. When ``None``
            (the default), argparse reads ``sys.argv[1:]``, exactly as before.
            Passing an explicit list allows programmatic use and testing.

    Returns:
        argparse.Namespace: Parsed arguments with the attributes
        ``data_path``, ``epochs``, ``lr``, ``batch_size``, ``ckpt_path``,
        ``accelerator`` and ``devices``.
    """
    parser = argparse.ArgumentParser(description="Train 2D CSNet model on DRIVE dataset")
    parser.add_argument('--data_path', type=str,
                        default='/home/xmoravc/DP/data/Drive_dataset/',
                        help='Path to the DRIVE dataset directory')
    parser.add_argument('--epochs', type=int, default=750,
                        help='Number of training epochs')
    parser.add_argument('--lr', type=float, default=0.0001,
                        help='Learning rate')
    parser.add_argument('--batch_size', type=int, default=1,
                        help='Batch size')
    parser.add_argument('--ckpt_path', type=str, default='checkpoint_tnts/',
                        help='Path to save model checkpoints')
    parser.add_argument('--accelerator', type=str, default='gpu',
                        help='Training device: "gpu" or "cpu"')
    parser.add_argument('--devices', type=int, default=1,
                        help='Number of devices (e.g., GPUs) to use')
    # argparse treats argv=None as "use sys.argv[1:]", preserving CLI behaviour.
    return parser.parse_args(argv)


def main():
    """
    Train and evaluate the 2D CSNet model.

    Parses CLI arguments and starts a wandb run. It then builds the DRIVE
    data module and the CSNet Lightning module, and trains with a checkpoint
    callback that keeps the single best model by the monitored validation
    loss. After training, it runs the test loop on that best checkpoint
    (or on the current weights if no checkpoint was saved) and closes the
    wandb run.
    """
    args = parse_args()

    # Report GPU availability up front; --accelerator defaults to 'gpu'.
    print("Cuda available:", torch.cuda.is_available())

    # Start the wandb run explicitly so its name is available for the
    # checkpoint filename below; WandbLogger reuses the active run.
    wandb.init(project='AI_project')

    # Ensure checkpoint path exists
    os.makedirs(args.ckpt_path, exist_ok=True)

    # Enable anomaly detection for debugging unstable gradients.
    # NOTE(review): this adds overhead to every backward pass for all epochs;
    # consider disabling it for long production runs.
    torch.autograd.set_detect_anomaly(True)

    # Load data module
    data_module = DriveDataset(
        root_dir=args.data_path,
        batch_size=args.batch_size,
        num_workers=4
    )

    # Initialise CSNet model
    model = CSNetLightning(
        classes=1,
        channels=3,  # RGB input for DRIVE dataset
        lr=args.lr
    )

    # Setup wandb logger
    logger = WandbLogger(project='AI_project', log_model=True)
    # The run's display name may be None/empty (e.g. in offline mode); fall
    # back to the run id so the checkpoint is not saved as "None.ckpt".
    checkpoint_filename = str(wandb.run.name or wandb.run.id)

    # Callback to save the best model
    checkpoint_callback = ModelCheckpoint(
        dirpath=args.ckpt_path,
        filename=checkpoint_filename,
        # Must match the metric key logged by CSNetLightning — presumably
        # self.log('val loss', ...); verify against the model.
        monitor='val loss',
        mode='min',
        save_top_k=1,
    )

    # Monitor learning rate
    lr_monitor = LearningRateMonitor(logging_interval='epoch')

    # Setup trainer
    trainer = pl.Trainer(
        max_epochs=args.epochs,
        accelerator=args.accelerator,
        devices=args.devices,
        logger=logger,
        callbacks=[checkpoint_callback, lr_monitor],
        log_every_n_steps=1
    )

    # Run training
    trainer.fit(model, data_module)

    # Evaluate the best checkpoint tracked by ModelCheckpoint rather than the
    # last-epoch weights; fall back to the current weights if none was saved.
    best_ckpt = 'best' if checkpoint_callback.best_model_path else None
    trainer.test(model, datamodule=data_module, ckpt_path=best_ckpt)

    # Finish wandb session
    wandb.finish()


if __name__ == '__main__':
    # Script entry point: train and test only when executed directly,
    # never as a side effect of importing this module.
    main()


