import torch

from torchvision.models.detection import (maskrcnn_resnet50_fpn_v2, MaskRCNN_ResNet50_FPN_V2_Weights)
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

import motmetrics as mm

import seaborn as sns
from random import shuffle

from utils import track

# Number of foreground object classes. The predictor heads are built with
# NUM_CLASSES + 1 outputs (the extra output is presumably the background
# class, per torchvision's convention).
NUM_CLASSES = 2
# Forwarded to utils.track as `is_transformer`; presumably selects a
# transformer-specific code path there — confirm in utils.track.
IS_TRANSFORMER = False
# Forwarded to utils.track as `confidence_threshold`; presumably detections
# scoring below it are discarded — confirm in utils.track.
CONFIDENCE_THRESHOLD = 0.5

# path to the checkpoint used for evaluation; it is loaded with torch.load and
# must contain a "model_state_dict" entry
MODEL_PATH = "models/model_maskrcnn_final"

# whether to save images with drawn bounding boxes (predictions) to disk
SAVE_OUTPUT = False

# folder where predictions will be saved
# (the directory structure will be the same as that of the training data)
OUTPUT_FOLDER = "outputs"


if __name__ == "__main__":
    dev = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    COLORS = dict()

    COLORS = [(int(b * 255), int(g * 255), int(r * 255))
            for (r, g, b) in list(sns.color_palette("hls", 16))]

    shuffle(COLORS)
    
    model = maskrcnn_resnet50_fpn_v2(
    weights=MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT)

    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_channels=256,
                                                    dim_reduced=256,
                                                    num_classes=NUM_CLASSES + 1)

    model.roi_heads.box_predictor = FastRCNNPredictor(in_channels=1024,
                                                    num_classes=NUM_CLASSES + 1)
    
    saved_model = torch.load(MODEL_PATH, weights_only=False)

    model.load_state_dict(saved_model["model_state_dict"])

    model = model.to(dev)
    model.eval()

    print("Running detection model and tracker...")

    mot_metrics, map_metrics = track(
        "testing", model,
        is_transformer=IS_TRANSFORMER,
        save_output=SAVE_OUTPUT,
        output_folder=OUTPUT_FOLDER,
        confidence_threshold=CONFIDENCE_THRESHOLD,
        colors=COLORS,
        dev=dev)

    
    i = 1
    # usage described in https://pypi.org/project/motmetrics/
    for acc in mot_metrics:
        print("-"*50)
        print(f"MOT metrics for sequence number {i}")
        i += 1 
        mh = mm.metrics.create()
        metrics = mh.compute(acc, return_dataframe=False)
        
        precision = metrics['precision'].item()
        recall = metrics["recall"].item()
        mota = metrics["mota"].item()
        motp = metrics["motp"].item()

        print(f"Precision: {precision}, Recall: {recall}, MOTA: {mota}, MOTP: {motp}")


    for map_metric in map_metrics:
        print()
        print("Detection metrics (ignoring labels)")

        map50 = map_metric["map_50"].item()
        map75 = map_metric["map_75"].item()
        map_avg = map_metric["map"].item()
        map_car = map_metric["map_per_class"][0].item()
        map_pedestrian = map_metric["map_per_class"][1].item()

        print(f"mAP: {map_avg}, mAP@50: {map50}, mAP@75: {map75}," + \
              f"mAP_car: {map_car}, mAP_pedestrian: {map_pedestrian}")