#!/usr/bin/env python3
import argparse
import os
import sys

import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image, ImageFile
from safetensors.torch import load_file, safe_open
from transformers import (
    AutoTokenizer,
    BertTokenizer,
    DebertaTokenizer
)

# Import models from training.py
from training import MultiModalClassifier

# Ensure truncated images can be loaded: with this Pillow flag set, truncated
# image files are loaded instead of raising an error. The flag is a module
# attribute of PIL.ImageFile, so it affects every image opened in this process.
ImageFile.LOAD_TRUNCATED_IMAGES = True


def classify(text, image_path, model, tokenizer, image_transform, device, max_length=128):
    """Classify a single (text, image) pair with a multimodal model.

    Args:
        text: Raw text to classify.
        image_path: Path of the image file paired with *text*.
        model: Callable taking ``input_ids``, ``attention_mask`` and
            ``pixel_values`` keyword tensors; it may return a logits tensor
            or a dict-like object with a ``"logits"`` entry.
        tokenizer: Hugging Face-style tokenizer, called with
            ``return_tensors='pt'``.
        image_transform: Callable turning a PIL RGB image into a tensor;
            a batch dimension is added here.
        device: Torch device the model inputs are moved to.
        max_length: Token length the text is truncated / padded to.

    Returns:
        ``(predicted_index, probabilities)``. *probabilities* is a NumPy
        array of softmax scores with one row, since the batch holds a single
        example. Returns ``(None, None)`` if the image could not be loaded.
    """
    # Load the image first, so a bad path fails before any tokenization work.
    # The `with` block closes the file Pillow opened. convert() returns an
    # independent in-memory image, so it stays usable after the file closes.
    try:
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
    except Exception as e:
        # CLI boundary: report any load failure and signal it through the
        # (None, None) return value that callers check for.
        print(f"Error loading image: {e}", file=sys.stderr)
        return None, None

    # Calling the tokenizer directly is the documented replacement for the
    # deprecated encode_plus(); the arguments are the same.
    encoded_text = tokenizer(
        text,
        add_special_tokens=True,
        max_length=max_length,
        truncation=True,
        padding='max_length',
        return_tensors='pt'
    )

    pixel_values = image_transform(image).unsqueeze(0).to(device)
    input_ids = encoded_text['input_ids'].to(device)
    attention_mask = encoded_text['attention_mask'].to(device)

    with torch.no_grad():
        outputs = model(input_ids=input_ids,
                        attention_mask=attention_mask, pixel_values=pixel_values)
        # Accept both plain-tensor and dict-like model outputs.
        logits = outputs["logits"] if isinstance(outputs, dict) else outputs
        probs = F.softmax(logits, dim=1)
        pred = torch.argmax(probs, dim=1)

    return pred.item(), probs.cpu().numpy()


# Header metadata keys that main() needs to rebuild the model.
_REQUIRED_METADATA_KEYS = (
    'text_model_name',
    'image_model_name',
    'num_labels',
    'image_model_input_dim',
)

# Model-name substrings that contain 'bert' but whose tokenizers are not
# BERT WordPiece (BPE / SentencePiece vocabularies); BertTokenizer cannot
# load these, so they are routed through AutoTokenizer.
_NON_WORDPIECE_BERT_VARIANTS = ('modernbert', 'roberta', 'albert', 'camembert')


def _load_tokenizer(text_model_name):
    """Return a tokenizer for *text_model_name*, chosen by substrings of the name."""
    if 'deberta' in text_model_name:
        # NOTE(review): DebertaTokenizer is the DeBERTa (v1) tokenizer; names
        # of deberta-v2/v3 checkpoints also match this branch. Confirm it
        # mirrors the tokenizer used in training.py.
        return DebertaTokenizer.from_pretrained(text_model_name)
    if 'bert' in text_model_name and not any(
            variant in text_model_name for variant in _NON_WORDPIECE_BERT_VARIANTS):
        return BertTokenizer.from_pretrained(text_model_name)
    return AutoTokenizer.from_pretrained(text_model_name)


def _label_mapping(num_labels):
    """Return the class-index -> label-name table for a *num_labels*-way head."""
    if num_labels == 2:
        return {0: "Fake", 1: "True"}
    # Every other head size uses the 6-way scheme; indices outside this table
    # get a generic name when printed.
    return {
        0: "True",
        1: "Satire/Parody",
        2: "Misleading Content",
        3: "Imposter Content",
        4: "False Connection",
        5: "Manipulated Content"
    }


def main():
    """Command-line entry point: classify one (text, image) pair with a saved model.

    Reads ``model.safetensors`` from ``--model_path`` and rebuilds the
    ``MultiModalClassifier`` described by the file's header metadata. It then
    runs ``classify`` and prints the predicted class and per-class
    probabilities.

    Returns:
        Process exit status: 0 on success; 1 if the model file is missing,
        its metadata is unusable, or classification fails.
    """
    parser = argparse.ArgumentParser(
        description='Multimodal fake news detection inference')
    parser.add_argument('--text', type=str, required=True,
                        help='Text content to classify')
    parser.add_argument('--image', type=str, required=True,
                        help='Path to the image file')
    parser.add_argument('--model_path', type=str, required=True,
                        help='Folder with the saved model')

    args = parser.parse_args()

    model_path = os.path.join(args.model_path, 'model.safetensors')
    if not os.path.isfile(model_path):
        print(f"Model file not found: {model_path}", file=sys.stderr)
        return 1

    # Read only the header first. It is cheap and lets us reject a wrong model
    # before loading any weights. `with` releases the safe_open handle.
    with safe_open(model_path, framework="pt") as handle:
        # metadata() returns None when the file was saved without metadata.
        metadata = handle.metadata() or {}

    model_type = metadata.get('model_type')
    if model_type != 'multimodal':
        print(
            f"Model must be multimodal! Model type is {model_type}.", file=sys.stderr)
        return 1

    missing = [key for key in _REQUIRED_METADATA_KEYS if key not in metadata]
    if missing:
        print(f"Model metadata is missing required keys: {', '.join(missing)}",
              file=sys.stderr)
        return 1

    text_model_name = metadata["text_model_name"]
    num_labels = int(metadata["num_labels"])
    image_dim = int(metadata['image_model_input_dim'])

    image_transform = transforms.Compose([
        transforms.Resize((image_dim, image_dim)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    tokenizer = _load_tokenizer(text_model_name)

    model = MultiModalClassifier(
        text_model_name=text_model_name,
        image_model_name=metadata["image_model_name"],
        num_labels=num_labels
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # load_file() loads tensors on CPU by default; the model is moved to the
    # target device afterwards. The temporary state dict is freed after the call.
    model.load_state_dict(load_file(model_path))
    model.to(device)
    model.eval()

    label_mapping = _label_mapping(num_labels)

    predicted_class, probabilities = classify(
        args.text,
        args.image,
        model,
        tokenizer,
        image_transform,
        device
    )

    if predicted_class is None:
        print("Error during classification. Please check the input.",
              file=sys.stderr)
        return 1

    predicted_name = label_mapping.get(predicted_class, f"Class {predicted_class}")
    print("\nPrediction Results:")
    print(f"Predicted Class: {predicted_class} ({predicted_name})")
    print(f"Confidence: {probabilities[0][int(predicted_class)]:.4f}")

    print("\nClass Probabilities:")
    for i, prob in enumerate(probabilities[0]):
        class_name = label_mapping.get(i, f"Class {i}")
        print(f"{class_name}: {prob:.4f}")
    return 0


if __name__ == "__main__":
    # Propagate main()'s return value as the exit status (None -> 0).
    sys.exit(main())
