#!/usr/bin/env python3
import argparse
import os
import torch
import pandas as pd
import torchvision.transforms as transforms
from transformers import (
    AutoTokenizer,
    BertTokenizer,
    DebertaTokenizer,
    Trainer,
    TrainingArguments
)
import logging

# Import models from training.py
from training import (
    TextClassifier,
    ImageClassifier,
    MultiModalClassifier,
    FakedditDataset,
    compute_metrics,
    get_datasets
)
from safetensors.torch import load_file, safe_open


def main():
    """Evaluate a saved fake-news classifier on the held-out test split.

    Loads weights and metadata from ``<model_path>/model.safetensors``,
    rebuilds the matching model (text / image / multimodal), tokenizer and
    image transform from that metadata, then reports accuracy, precision,
    recall and F1 on the test split via a HuggingFace ``Trainer``.

    Command-line arguments:
        --model_path   (required) folder containing ``model.safetensors``
        --output_dir   where Trainer writes evaluation artifacts
        --batch_size   per-device evaluation batch size
    """
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
    )

    parser = argparse.ArgumentParser(
        description='Evaluate fake news detection models')
    parser.add_argument('--model_path', type=str, required=True,
                        help='Folder with the saved model')
    parser.add_argument('--output_dir', type=str, default='./results',
                        help='Directory to save evaluation results')
    parser.add_argument('--batch_size', type=int, default=64,
                        help='Evaluation batch size')

    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    # 6-way labels dataset must be read in either case
    df = pd.read_csv('fakeddit_with_paths.tsv',
                     delimiter='\t', encoding='ISO-8859-1')
    df = df[['id', 'clean_title', 'image_url', '2_way_label', '6_way_label']]
    df = df.dropna()
    df = df.rename(columns={'clean_title': 'Title', 'image_url': 'Image',
                   '2_way_label': 'Label2', '6_way_label': 'Label6'})

    model_path = os.path.join(args.model_path, 'model.safetensors')

    loaded_state = load_file(model_path)
    # Use safe_open as a context manager so the checkpoint file handle is
    # released.  metadata() returns None when the checkpoint was saved
    # without metadata; fall back to {} so the lookups below don't raise
    # AttributeError (the required keys will then fail with a clear KeyError).
    with safe_open(model_path, framework="pt") as checkpoint:
        metadata = checkpoint.metadata() or {}

    # Older checkpoints may predate the 'model_type' key; default to the
    # full multimodal architecture.
    metadata.setdefault('model_type', 'multimodal')

    if int(metadata["num_labels"]) == 2:
        # 2-way labels dataset: IFND plus a 20k sample from Fakeddit
        df_2way = pd.read_csv('ifnd_with_paths.csv',
                              delimiter=',', encoding='windows-1252')
        df_2way = df_2way[['id', 'Statement', 'Image', 'Label']]
        df_2way = df_2way.dropna()
        df_2way = df_2way.rename(
            columns={'Statement': 'Title', 'Label': 'Label2'})
        # NOTE(review): the mixed casing ('TRUE' vs 'Fake') mirrors the raw
        # IFND labels — confirm against the dataset before "fixing" it.
        df_2way['Label2'] = df_2way['Label2'].replace({'TRUE': 1, 'Fake': 0})

        df_sample_temp = df.sample(min(20000, df.shape[0]), random_state=42)
        df_sample_temp = df_sample_temp[['id', 'Title', 'Image', 'Label2']]
        df_2way = pd.concat([df_2way, df_sample_temp], ignore_index=True)
        df = df_2way.rename(columns={'Label2': 'Label'})
    else:
        df = df.rename(columns={'Label6': 'Label'})

    # Only the test split is evaluated; train/validation splits are discarded.
    _, _, test_df = get_datasets(df)

    image_transform = transforms.Compose([
        transforms.Resize(
            (int(metadata['image_model_input_dim']), int(metadata['image_model_input_dim']))),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    # Load tokenizer based on text model; the explicit DeBERTa/BERT branches
    # must match whatever training.py used, AutoTokenizer covers the rest
    # (including ModernBERT, which the 'bert' substring would otherwise catch).
    if 'deberta' in metadata["text_model_name"]:
        tokenizer = DebertaTokenizer.from_pretrained(
            metadata["text_model_name"])
    elif 'bert' in metadata["text_model_name"] and 'modernbert' not in metadata["text_model_name"]:
        tokenizer = BertTokenizer.from_pretrained(metadata["text_model_name"])
    else:
        tokenizer = AutoTokenizer.from_pretrained(metadata["text_model_name"])

    test_dataset = FakedditDataset(test_df, tokenizer, image_transform)

    # Rebuild the architecture recorded in the checkpoint metadata.
    if metadata['model_type'] == 'text':
        model = TextClassifier(
            text_model_name=metadata["text_model_name"],
            num_labels=int(metadata["num_labels"])
        )
    elif metadata['model_type'] == 'image':
        model = ImageClassifier(
            image_model_name=metadata["image_model_name"],
            num_labels=int(metadata["num_labels"])
        )
    else:  # multimodal
        model = MultiModalClassifier(
            text_model_name=metadata["text_model_name"],
            image_model_name=metadata["image_model_name"],
            num_labels=int(metadata["num_labels"])
        )

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load the weights
    model.load_state_dict(loaded_state)
    model.to(device)
    model.eval()

    eval_args = TrainingArguments(
        output_dir=args.output_dir,
        per_device_eval_batch_size=args.batch_size,
        do_eval=True,
        report_to='none'
    )

    trainer = Trainer(
        model=model,
        args=eval_args,
        eval_dataset=test_dataset,
        compute_metrics=compute_metrics
    )

    metrics = trainer.evaluate()
    # Lazy %-formatting so the format work is skipped if INFO is disabled.
    logging.info("Evaluation Metrics:")
    logging.info("Test Accuracy: %.4f", metrics['eval_accuracy'])
    logging.info("Precision: %.4f", metrics['eval_precision'])
    logging.info("Recall: %.4f", metrics['eval_recall'])
    logging.info("F1 Score: %.4f", metrics['eval_f1'])


# Run evaluation only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
