#!/usr/bin/env python3
import argparse
import os
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image, ImageFile
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from torch.utils.data import Dataset
from transformers import (
    AutoModel,
    AutoTokenizer,
    BertModel,
    BertTokenizer,
    ModernBertModel,
    DebertaModel,
    DebertaTokenizer,
    ViTModel,
    SwinModel,
    Trainer,
    TrainingArguments
)
import logging
from safetensors.torch import save_file

# Ensure truncated images can be loaded
ImageFile.LOAD_TRUNCATED_IMAGES = True


class TextClassifier(nn.Module):
    """
    Text classification model class.

    A pretrained Hugging Face text encoder whose first-token hidden state is
    projected to ``proj_dim``, passed through dropout and classified by a
    two-layer MLP head.

    Args:
        text_model_name: Hugging Face model id of the text encoder.
        num_labels: number of output classes.
        proj_dim: output width of the projection layer.
        dropout_prob: dropout probability after the projection and inside the
            classifier head.
    """

    def __init__(self, text_model_name, num_labels, proj_dim=224, dropout_prob=0.1):
        super(TextClassifier, self).__init__()

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        lowered_name = text_model_name.lower()

        if 'deberta' in lowered_name:
            # DeBERTa is loaded without device_map="auto" and moved explicitly.
            # AutoModel resolves the concrete class from the checkpoint config,
            # so DeBERTa-v2/v3 checkpoints are not forced into the v1
            # DebertaModel architecture, whose weights they do not match.
            self.text_model = AutoModel.from_pretrained(
                text_model_name).to(device)
        elif 'modernbert' in lowered_name:
            self.text_model = ModernBertModel.from_pretrained(
                text_model_name, device_map="auto")
        elif text_model_name == 'google-bert/bert-base-uncased':
            self.text_model = BertModel.from_pretrained(
                text_model_name, device_map="auto")
        else:
            # Any other checkpoint: best effort via AutoModel. forward() assumes
            # it returns a 3-D last_hidden_state whose position 0 is usable.
            self.text_model = AutoModel.from_pretrained(
                text_model_name).to(device)

        text_hidden_size = self.text_model.config.hidden_size
        self.proj = nn.Linear(text_hidden_size, proj_dim)
        self.dropout = nn.Dropout(dropout_prob)
        self.loss_fct = nn.CrossEntropyLoss()
        self.classifier = nn.Sequential(
            nn.Linear(proj_dim, 224),
            nn.ReLU(),
            nn.Dropout(dropout_prob),
            nn.Linear(224, num_labels)
        )

        logging.info(
            f'Loaded {num_labels}-way labels\nText encoder: {text_model_name}')

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        """
        Classify a batch of token sequences.

        Extra keyword arguments (e.g. ``pixel_values`` supplied by a dataset
        shared with the image/multimodal models) are accepted and ignored.

        Returns:
            ``{"loss": ..., "logits": ...}`` when ``labels`` is given,
            otherwise the logits tensor alone.
        """
        outputs = self.text_model(
            input_ids=input_ids, attention_mask=attention_mask)
        # Hidden state of the first token ([CLS] for BERT-style encoders).
        cls = outputs.last_hidden_state[:, 0, :]
        emb = self.dropout(self.proj(cls))
        logits = self.classifier(emb)

        if labels is None:
            return logits
        loss = self.loss_fct(logits, labels)
        return {"loss": loss, "logits": logits}


class ImageClassifier(nn.Module):
    """
    Image classification model class.

    A pretrained Hugging Face vision encoder whose output is pooled to one
    vector per image, projected to ``proj_dim``, passed through dropout and
    classified by a two-layer MLP head.

    Args:
        image_model_name: Hugging Face model id of the image encoder. The
            substrings "vit", "swin" and "resnet" (checked in that order)
            select both the loading class and the pooling strategy.
        num_labels: number of output classes.
        proj_dim: output width of the projection layer.
        dropout_prob: dropout probability after the projection and inside the
            classifier head.
    """

    def __init__(self, image_model_name, num_labels, proj_dim=224, dropout_prob=0.1):
        super(ImageClassifier, self).__init__()
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.image_model_name = image_model_name
        lowered_name = image_model_name.lower()

        if "vit" in lowered_name:
            self.image_model = ViTModel.from_pretrained(
                image_model_name, device_map="auto")
            image_hidden_size = self.image_model.config.hidden_size
        elif "swin" in lowered_name:
            self.image_model = SwinModel.from_pretrained(
                image_model_name, device_map="auto")
            image_hidden_size = self.image_model.config.hidden_size
        elif "resnet" in lowered_name:
            # Hugging Face ResNets load via AutoModel; their config exposes
            # per-stage widths, the last of which is the output channel count.
            self.image_model = AutoModel.from_pretrained(
                image_model_name, device_map="auto")
            image_hidden_size = self.image_model.config.hidden_sizes[-1]
        else:
            self.image_model = AutoModel.from_pretrained(
                image_model_name).to(device)
            image_hidden_size = self.image_model.config.hidden_size

        self.proj = nn.Linear(image_hidden_size, proj_dim)
        self.dropout = nn.Dropout(dropout_prob)
        self.loss_fct = nn.CrossEntropyLoss()
        self.classifier = nn.Sequential(
            nn.Linear(proj_dim, 224),
            nn.ReLU(),
            nn.Dropout(dropout_prob),
            nn.Linear(224, num_labels)
        )

        logging.info(
            f'Loaded {num_labels}-way labels\nImage encoder: {image_model_name}')

    def _pool(self, last_hidden_state):
        """
        Reduce the encoder's last hidden state to one vector per image.

        The branch order mirrors the encoder selection in ``__init__`` so the
        pooling always matches the architecture that was actually loaded.
        """
        name = self.image_model_name.lower()
        if "vit" in name:
            # ViT prepends a [CLS] token at position 0.
            return last_hidden_state[:, 0, :]
        if "swin" in name:
            # Swin has no [CLS] token: the sequence is purely patch tokens, so
            # position 0 would only describe one corner window. Average over
            # the token axis instead (the same reduction Swin's own pooler uses).
            return torch.mean(last_hidden_state, dim=1)
        if "resnet" in name:
            # 4-D feature map: global average pool over the spatial dims.
            return torch.mean(last_hidden_state, dim=[2, 3])
        # Unknown encoder: assume a [CLS]-style token at position 0.
        return last_hidden_state[:, 0, :]

    def forward(self, pixel_values=None, labels=None, **kwargs):
        """
        Classify a batch of images.

        Extra keyword arguments (e.g. ``input_ids`` supplied by a dataset
        shared with the text/multimodal models) are accepted and ignored.

        Returns:
            ``{"loss": ..., "logits": ...}`` when ``labels`` is given,
            otherwise the logits tensor alone.
        """
        outputs = self.image_model(pixel_values=pixel_values)
        pooled = self._pool(outputs.last_hidden_state)

        emb = self.dropout(self.proj(pooled))
        logits = self.classifier(emb)

        if labels is None:
            return logits
        loss = self.loss_fct(logits, labels)
        return {"loss": loss, "logits": logits}


class MultiModalClassifier(nn.Module):
    """
    Multimodal classification model class.

    Encodes the title with a pretrained text encoder and the image with a
    pretrained vision encoder, projects both to ``hidden_dim``, fuses them with
    an element-wise maximum and classifies the result with a two-layer MLP.

    Args:
        text_model_name: Hugging Face model id of the text encoder.
        image_model_name: Hugging Face model id of the image encoder. The
            substrings "vit", "swin" and "resnet" (checked in that order)
            select both the loading class and the pooling strategy.
        num_labels: number of output classes.
        hidden_dim: shared width both modality embeddings are projected to.
        dropout_prob: dropout probability after fusion and inside the head.
    """

    def __init__(
        self,
        text_model_name,
        image_model_name,
        num_labels,
        hidden_dim=224,
        dropout_prob=0.1
    ):
        super(MultiModalClassifier, self).__init__()
        device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        lowered_text = text_model_name.lower()
        if 'deberta' in lowered_text:
            # DeBERTa is loaded without device_map="auto" and moved explicitly.
            # AutoModel resolves the concrete class from the checkpoint config,
            # so DeBERTa-v2/v3 checkpoints are not forced into the v1
            # DebertaModel architecture, whose weights they do not match.
            self.text_model = AutoModel.from_pretrained(
                text_model_name).to(device)
        elif 'modernbert' in lowered_text:
            self.text_model = ModernBertModel.from_pretrained(
                text_model_name, device_map="auto")
        elif text_model_name == 'google-bert/bert-base-uncased':
            self.text_model = BertModel.from_pretrained(
                text_model_name, device_map="auto")
        else:
            # try to load any other model, but this might not work
            self.text_model = AutoModel.from_pretrained(
                text_model_name).to(device)

        text_hidden_size = self.text_model.config.hidden_size
        self.text_proj = nn.Linear(text_hidden_size, hidden_dim)

        self.image_model_name = image_model_name
        lowered_image = image_model_name.lower()

        if "vit" in lowered_image:
            self.image_model = ViTModel.from_pretrained(
                image_model_name, device_map="auto")
            image_hidden_size = self.image_model.config.hidden_size
        elif "swin" in lowered_image:
            self.image_model = SwinModel.from_pretrained(
                image_model_name, device_map="auto")
            image_hidden_size = self.image_model.config.hidden_size
        elif "resnet" in lowered_image:
            # ResNet configs expose per-stage widths; the last is the output
            # channel count of the final feature map.
            self.image_model = AutoModel.from_pretrained(
                image_model_name, device_map="auto")
            image_hidden_size = self.image_model.config.hidden_sizes[-1]
        else:
            # try to load any other model, but this might not work
            self.image_model = AutoModel.from_pretrained(
                image_model_name).to(device)
            image_hidden_size = self.image_model.config.hidden_size

        self.image_proj = nn.Linear(image_hidden_size, hidden_dim)

        self.loss_fct = nn.CrossEntropyLoss()
        self.dropout = nn.Dropout(dropout_prob)
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, 224),
            nn.ReLU(),
            nn.Dropout(dropout_prob),
            nn.Linear(224, num_labels)
        )

        logging.info(
            f'Loaded {num_labels}-way labels\nText encoder: {text_model_name}\nImage encoder: {image_model_name}')

    def _pool_image(self, last_hidden_state):
        """
        Reduce the image encoder's last hidden state to one vector per image.

        The branch order mirrors the encoder selection in ``__init__`` so the
        pooling always matches the architecture that was actually loaded.
        """
        name = self.image_model_name.lower()
        if "vit" in name:
            # ViT prepends a [CLS] token at position 0.
            return last_hidden_state[:, 0, :]
        if "swin" in name:
            # Swin has no [CLS] token: the sequence is purely patch tokens, so
            # position 0 would only describe one corner window. Average over
            # the token axis instead (the same reduction Swin's own pooler uses).
            return torch.mean(last_hidden_state, dim=1)
        if "resnet" in name:
            # 4-D feature map: global average pool over the spatial dims.
            return torch.mean(last_hidden_state, dim=[2, 3])
        # Unknown encoder: assume a [CLS]-style token at position 0.
        return last_hidden_state[:, 0, :]

    def forward(self, input_ids, attention_mask, pixel_values, labels=None):
        """
        Classify a batch of (title, image) pairs.

        Returns:
            ``{"loss": ..., "logits": ...}`` when ``labels`` is given,
            otherwise the logits tensor alone.
        """
        # Text branch: first-token ([CLS]) hidden state, projected.
        text_outputs = self.text_model(
            input_ids=input_ids, attention_mask=attention_mask)
        text_emb = self.text_proj(text_outputs.last_hidden_state[:, 0, :])

        # Image branch: architecture-specific pooling, projected.
        image_outputs = self.image_model(pixel_values=pixel_values)
        image_emb = self.image_proj(
            self._pool_image(image_outputs.last_hidden_state))

        # Fuse the two modalities with an element-wise maximum.
        fused = self.dropout(torch.max(text_emb, image_emb))
        logits = self.classifier(fused)

        if labels is None:
            return logits
        loss = self.loss_fct(logits, labels)
        return {"loss": loss, "logits": logits}


class FakedditDataset(Dataset):
    """
    PyTorch dataset yielding tokenized titles, transformed images and labels.

    Each row of ``dataframe`` must provide ``Title`` (text), ``Image`` (a path
    that ``PIL.Image.open`` accepts) and ``Label`` (integer class id).

    Args:
        dataframe: source rows; the index is reset so positional access works.
        tokenizer: Hugging Face tokenizer used to encode ``Title``.
        image_transform: callable mapping a PIL RGB image to a tensor.
        max_length: token length every title is truncated/padded to.
    """

    def __init__(self, dataframe, tokenizer, image_transform, max_length=128):
        self.data = dataframe.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.image_transform = image_transform
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return the model inputs and label for row ``idx`` as a dict of tensors."""
        row = self.data.iloc[idx]

        # Call the tokenizer directly (encode_plus is deprecated). str() guards
        # against titles that pandas parsed as numbers.
        encoded_text = self.tokenizer(
            str(row["Title"]),
            add_special_tokens=True,
            max_length=self.max_length,
            truncation=True,
            padding='max_length',
            return_tensors='pt'
        )

        # Close the file deterministically: PIL can keep the handle open for
        # lazily-loaded / multi-frame formats, leaking descriptors over many
        # epochs. convert() fully loads the pixels before the file is closed.
        with Image.open(row["Image"]) as img:
            image = img.convert("RGB")
        pixel_values = self.image_transform(image).float()

        label = torch.tensor(row['Label'], dtype=torch.long)

        return {
            'input_ids': encoded_text['input_ids'].squeeze(0),
            'attention_mask': encoded_text['attention_mask'].squeeze(0),
            'pixel_values': pixel_values,
            'labels': label
        }


def compute_metrics(eval_pred):
    """
    Compute accuracy and macro-averaged precision, recall and F1.

    ``eval_pred`` is unpacked into ``(logits, labels)``. The predicted class
    is the arg-max over the last axis of ``logits``. Classes with no
    predictions or no support score 0 rather than raising a warning.
    """
    raw_logits, gold = eval_pred
    predicted = np.argmax(raw_logits, axis=-1)
    macro_p, macro_r, macro_f1, _support = precision_recall_fscore_support(
        gold, predicted, average="macro", zero_division=0)
    return {
        "accuracy": accuracy_score(gold, predicted),
        "precision": macro_p,
        "recall": macro_r,
        "f1": macro_f1,
    }


def get_datasets(df, holdout_size=0.2, random_state=42, stratify_col=None):
    """
    Split ``df`` into train, validation and test DataFrames.

    A ``holdout_size`` fraction of rows is held out first. It is then divided
    evenly into validation and test, so the defaults give an 80/10/10 split.

    Args:
        df: the full dataset.
        holdout_size: fraction of rows reserved for validation + test.
        random_state: seed passed to both splits so results are reproducible.
        stratify_col: optional column name. When given, both splits preserve
            that column's class proportions; by default splitting is random.

    Returns:
        ``(train_df, val_df, test_df)`` — in that order.
    """
    outer_stratify = df[stratify_col] if stratify_col is not None else None
    train_df, val_test_df = train_test_split(
        df, test_size=holdout_size, random_state=random_state,
        stratify=outer_stratify)

    inner_stratify = (val_test_df[stratify_col]
                      if stratify_col is not None else None)
    val_df, test_df = train_test_split(
        val_test_df, test_size=0.5, random_state=random_state,
        stratify=inner_stratify)

    return train_df, val_df, test_df


# Architecture hyper-parameters shared by model construction and the metadata
# written next to the saved weights, so the two cannot drift apart.
_HIDDEN_DIM = 224
_DROPOUT_PROB = 0.1


def _parse_args():
    """Define the command-line interface and return the parsed arguments."""
    parser = argparse.ArgumentParser(
        description='Train fake news detection models',
        epilog='Expects fakeddit_with_paths.tsv in the current directory, plus ifnd_with_paths.csv for 2-way classification.\n',
        formatter_class=argparse.RawTextHelpFormatter
    )
    parser.add_argument('--model_type', type=str, required=True,
                        choices=['text', 'image', 'multimodal'],
                        help='Model type to train: text, image, or multimodal')
    parser.add_argument('--text_model', type=str, default='google-bert/bert-base-uncased',
                        help='Text model name in the huggingface format (e.g. google-bert/bert-base-uncased) - recommended choices include DeBERTa, BERT, ModernBERT')
    parser.add_argument('--image_model', type=str, default='microsoft/resnet-50',
                        help='Image model name in the huggingface format (e.g. microsoft/resnet-50) - recommended choices include ViT, Swin, ResNet')
    parser.add_argument('--image_model_input_dim', type=int, default=224,
                        help='Input dimensions for the image model (default: 224)')
    parser.add_argument('--output_dir', type=str, default='./results',
                        help='Directory to save model and results')
    parser.add_argument('--num_labels', type=int, default=2, choices=[2, 6],
                        help='Number of classification labels (2 or 6)')
    parser.add_argument('--batch_size', type=int, default=32,
                        help='Training batch size')
    parser.add_argument('--epochs', type=int, default=5,
                        help='Number of training epochs')
    parser.add_argument('--learning_rate', type=float, default=2e-5,
                        help='Learning rate')
    return parser.parse_args()


def _load_dataframe(num_labels):
    """
    Load the labelled posts as a DataFrame with Title, Image and Label columns.

    Fakeddit (``fakeddit_with_paths.tsv``) is always read. For 2-way labels,
    IFND (``ifnd_with_paths.csv``) is combined with a random sample of up to
    20,000 Fakeddit rows. Otherwise the Fakeddit 6-way labels are used.
    """
    # 6-way labels dataset must be read in either case
    df = pd.read_csv('fakeddit_with_paths.tsv',
                     delimiter='\t', encoding='ISO-8859-1')
    df = df[['id', 'clean_title', 'image_url', '2_way_label', '6_way_label']]
    df = df.dropna()
    df = df.rename(columns={'clean_title': 'Title', 'image_url': 'Image',
                   '2_way_label': 'Label2', '6_way_label': 'Label6'})

    if num_labels != 2:
        return df.rename(columns={'Label6': 'Label'})

    # 2-way labels dataset
    df_2way = pd.read_csv('ifnd_with_paths.csv',
                          delimiter=',', encoding='windows-1252')
    df_2way = df_2way[['id', 'Statement', 'Image', 'Label']]
    df_2way = df_2way.dropna()
    df_2way = df_2way.rename(
        columns={'Statement': 'Title', 'Label': 'Label2'})
    # IFND stores string labels; map them to integers ('TRUE' -> 1, 'Fake' -> 0).
    df_2way['Label2'] = df_2way['Label2'].replace({'TRUE': 1, 'Fake': 0})
    df_sample = df.sample(min(20000, df.shape[0]), random_state=42)
    df_sample = df_sample[['id', 'Title', 'Image', 'Label2']]
    df_2way = pd.concat([df_2way, df_sample], ignore_index=True)
    return df_2way.rename(columns={'Label2': 'Label'})


def _load_tokenizer(text_model_name):
    """
    Return the tokenizer for ``text_model_name``.

    The default BERT checkpoint keeps the slow ``BertTokenizer``. Every other
    checkpoint goes through ``AutoTokenizer``, which instantiates the tokenizer
    class declared by the checkpoint itself.
    """
    if text_model_name == 'google-bert/bert-base-uncased':
        return BertTokenizer.from_pretrained(text_model_name)
    # A substring test such as `'bert' in name` would also match RoBERTa,
    # ALBERT, CamemBERT, ... whose vocabularies BertTokenizer cannot load. The
    # v1 DebertaTokenizer likewise cannot load DeBERTa-v2/v3 vocabularies.
    return AutoTokenizer.from_pretrained(text_model_name)


def _build_model(args):
    """Instantiate the classifier selected by ``--model_type``."""
    if args.model_type == 'text':
        return TextClassifier(
            text_model_name=args.text_model,
            num_labels=args.num_labels,
            proj_dim=_HIDDEN_DIM,
            dropout_prob=_DROPOUT_PROB
        )
    if args.model_type == 'image':
        return ImageClassifier(
            image_model_name=args.image_model,
            num_labels=args.num_labels,
            proj_dim=_HIDDEN_DIM,
            dropout_prob=_DROPOUT_PROB
        )
    # multimodal
    return MultiModalClassifier(
        text_model_name=args.text_model,
        image_model_name=args.image_model,
        num_labels=args.num_labels,
        hidden_dim=_HIDDEN_DIM,
        dropout_prob=_DROPOUT_PROB
    )


def _save_model(trainer, model, args):
    """
    Save the trained model and (re)write its weights with descriptive metadata.

    ``trainer.save_model`` populates the output directory. Then
    ``model.safetensors`` inside it is written from ``model.state_dict()``
    together with the hyper-parameters needed to rebuild the model, replacing
    any file of that name. Returns the directory path.
    """
    model_path = os.path.join(
        args.output_dir, f"{args.model_type}_model_{args.num_labels}way")
    trainer.save_model(model_path)

    metadata = {
        "model_type": args.model_type,
        "text_model_name": args.text_model,
        "image_model_name": args.image_model,
        "image_model_input_dim": str(args.image_model_input_dim),
        "num_labels": str(args.num_labels),
        "hidden_dim": str(_HIDDEN_DIM),
        "dropout_prob": str(_DROPOUT_PROB)
    }

    safetensors_path = os.path.join(model_path, "model.safetensors")
    save_file(model.state_dict(), safetensors_path, metadata=metadata)
    return model_path


def main():
    """Parse CLI arguments, train the selected classifier and save it with metadata."""
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
    )

    args = _parse_args()

    image_transform = transforms.Compose([
        transforms.Resize((args.image_model_input_dim,
                          args.image_model_input_dim)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    df = _load_dataframe(args.num_labels)
    train_df, val_df, _ = get_datasets(df)

    tokenizer = _load_tokenizer(args.text_model)
    train_dataset = FakedditDataset(train_df, tokenizer, image_transform)
    val_dataset = FakedditDataset(val_df, tokenizer, image_transform)

    # Evaluate and checkpoint once per epoch; keep the most accurate epoch.
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size*2,
        dataloader_pin_memory=True,
        eval_strategy="epoch",
        save_strategy="epoch",
        logging_strategy='epoch',
        learning_rate=args.learning_rate,
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        report_to='none'
    )

    model = _build_model(args)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics
    )

    trainer.train()

    logging.info("Saving model")
    model_path = _save_model(trainer, model, args)

    logging.info(f"Model saved to {model_path} with metadata")


if __name__ == '__main__':
    main()
