import torch

import motmetrics as mm

from transformers import DetrForObjectDetection

import seaborn as sns
from random import shuffle

from utils import track


# number of object classes the detector predicts (passed to DETR as num_labels)
NUM_CLASSES = 2
# forwarded to utils.track as is_transformer (presumably selects DETR-style
# output handling there — confirm in utils.track)
IS_TRANSFORMER = True
# forwarded to utils.track as confidence_threshold (presumably detections
# scoring below it are dropped — confirm in utils.track)
CONFIDENCE_THRESHOLD = 0.5

# path to the checkpoint used for evaluation; it is loaded with torch.load and
# must contain a "model_state_dict" entry
MODEL_PATH = "models/model_transformer_final"

# whether to save images with drawn bounding boxes (predictions) to disk
SAVE_OUTPUT = False

# folder where predictions are saved
# (the directory structure is the same as that of the training data)
OUTPUT_FOLDER = "outputs_transformer"


# motmetrics metric identifiers reported per sequence; requesting them
# explicitly means motmetrics computes only what is printed.
MOT_METRIC_NAMES = ["precision", "recall", "mota", "motp"]


def build_colors(num_colors=16):
    """Return ``num_colors`` distinct colors as shuffled (B, G, R) int tuples.

    Colors come from seaborn's "hls" palette. Each (r, g, b) float triple in
    [0, 1] is scaled to 0-255 ints and reordered to (b, g, r). Presumably this
    is because the drawing code in ``utils.track`` uses OpenCV's BGR order;
    TODO confirm.
    """
    colors = [(int(b * 255), int(g * 255), int(r * 255))
              for (r, g, b) in sns.color_palette("hls", num_colors)]
    shuffle(colors)
    return colors


def load_model(dev):
    """Build the DETR detector and load the fine-tuned weights from MODEL_PATH.

    Args:
        dev: torch.device on which both the checkpoint tensors and the model
            are placed.

    Returns:
        A ``DetrForObjectDetection`` on ``dev``, switched to eval mode.
    """
    # The pretrained checkpoint has a different number of classes than
    # NUM_CLASSES. ignore_mismatched_sizes=True makes from_pretrained
    # re-initialize the mismatched weights and log a warning instead of
    # raising. Those weights are overwritten by the fine-tuned state dict below.
    model = DetrForObjectDetection.from_pretrained(
        "facebook/detr-resnet-50",
        num_labels=NUM_CLASSES,
        ignore_mismatched_sizes=True)

    # Without map_location, a checkpoint saved from a GPU cannot be
    # deserialized on a CPU-only machine, even though `dev` falls back to CPU.
    checkpoint = torch.load(MODEL_PATH, map_location=dev, weights_only=True)
    model.load_state_dict(checkpoint["model_state_dict"], assign=True)

    model = model.to(dev)
    model.eval()
    return model


def format_mot_metrics(metrics):
    """Return the tracking-metrics summary line for one sequence.

    Args:
        metrics: mapping with "precision", "recall", "mota" and "motp"
            entries, e.g. the dict returned by motmetrics'
            ``MetricsHost.compute(..., return_dataframe=False)``. Values may
            be numpy scalars, 0-d tensors or plain floats (``float()`` accepts
            all of them).
    """
    precision = float(metrics["precision"])
    recall = float(metrics["recall"])
    mota = float(metrics["mota"])
    motp = float(metrics["motp"])
    return (f"Precision: {precision}, Recall: {recall}, "
            f"MOTA: {mota}, MOTP: {motp}")


def format_map_metrics(map_metric):
    """Return the detection (mAP) summary line for one metrics result.

    Args:
        map_metric: mapping with scalar "map", "map_50" and "map_75" entries
            and an indexable "map_per_class". These keys match torchmetrics'
            MeanAveragePrecision output; presumably that is the source —
            confirm in utils.track.
    """
    # NOTE(review): assumes class index 0 is "car" and index 1 is "pedestrian",
    # and that map_per_class holds one value per class. torchmetrics returns a
    # single -1 there when class_metrics=False, which would make [0] fail.
    # Confirm against utils.track.
    map_avg = float(map_metric["map"])
    map50 = float(map_metric["map_50"])
    map75 = float(map_metric["map_75"])
    map_car = float(map_metric["map_per_class"][0])
    map_pedestrian = float(map_metric["map_per_class"][1])
    return (f"mAP: {map_avg}, mAP@50: {map50}, mAP@75: {map75}, "
            f"mAP_car: {map_car}, mAP_pedestrian: {map_pedestrian}")


def main():
    """Run detection + tracking on the "testing" split and print the metrics."""
    dev = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    colors = build_colors()
    model = load_model(dev)

    print("Running detection model and tracker...")

    mot_metrics, map_metrics = track(
        "testing", model,
        is_transformer=IS_TRANSFORMER,
        save_output=SAVE_OUTPUT,
        output_folder=OUTPUT_FOLDER,
        confidence_threshold=CONFIDENCE_THRESHOLD,
        colors=colors,
        dev=dev)

    print()
    # usage described in https://pypi.org/project/motmetrics/
    # One metrics host is enough for all sequences.
    mh = mm.metrics.create()
    for seq_number, acc in enumerate(mot_metrics, start=1):
        print("-" * 50)
        print(f"MOT metrics for sequence number {seq_number}")
        metrics = mh.compute(acc, metrics=MOT_METRIC_NAMES,
                             return_dataframe=False)
        print(format_mot_metrics(metrics))

    for map_metric in map_metrics:
        print()
        print("Detection metrics (ignoring labels)")
        print(format_map_metrics(map_metric))


if __name__ == "__main__":
    main()