import numpy as np
import torch

from pathlib import Path
import cv2
import skimage.io as skio

from torchvision.ops import box_iou
from scipy.spatial import distance_matrix


from datasets import get_sequences, SequentialDataset

import torch
import numpy as np
from torch.utils.data import DataLoader

from deep_sort_realtime.deepsort_tracker import DeepSort

import motmetrics as mm


from datasets import create_lists

from torchmetrics.detection.mean_ap import MeanAveragePrecision


from transformers import DetrImageProcessor


from training import get_detr_inputs


def change_bounding_box_format(predictions, use_masks):
    """
    Convert a prediction dict into a list of ([box-coords], confidence, label) tuples.

    Assumes boxes are given as (x1, y1, x2, y2). Each one becomes
    [x1, y1, x2 - x1, y2 - y1], and every value is converted to a plain
    Python number. No confidence filtering is done here; that is the job of
    ``filter_low_confidence``.

    Args:
        predictions: dict with "boxes", "labels" and "scores" tensors, plus
            "masks" when ``use_masks`` is True.
        use_masks: if True, also binarise each mask at 0.5.

    Returns:
        (result, filtered_masks). ``result`` is the list of tuples described
        above. ``filtered_masks`` holds one boolean numpy array per box: the
        first slice along that mask's leading axis (presumably masks are
        (1, H, W); TODO confirm). It is empty when ``use_masks`` is False.
    """
    boxes = predictions["boxes"]
    labels = predictions["labels"]
    scores = predictions["scores"]
    masks = predictions["masks"] if use_masks else None

    result = []
    filtered_masks = []

    for i, (box, score, label) in enumerate(zip(boxes, scores, labels)):
        x1, y1, x2, y2 = box[0], box[1], box[2], box[3]
        result.append((
            [x1.item(), y1.item(), (x2 - x1).item(), (y2 - y1).item()],
            score.item(),
            label.item()))

        if use_masks:
            # np.bool no longer exists (removed in NumPy 1.24), and bool() of
            # a multi-element tensor raises. Threshold element-wise, then convert.
            filtered_masks.append((masks[i] > 0.5).cpu().numpy()[0, :, :])

    return result, filtered_masks


def bbox_get_polygon_format(predictions):
    """
    Flatten a prediction dict into three parallel plain-Python lists.

    Args:
        predictions: dict with "boxes", "labels" and "scores" tensors.

    Returns:
        [boxes, labels, confidences]:
        - ``boxes`` is one flat list with the four coordinates of every box
          back to back ([b0_0, b0_1, b0_2, b0_3, b1_0, ...]), unchanged in
          order and format.
        - ``labels`` and ``confidences`` hold one value per box.
        No confidence filtering is performed.
    """
    boxes = []
    labels = []
    confidences = []

    for box, label, score in zip(predictions["boxes"],
                                 predictions["labels"],
                                 predictions["scores"]):
        boxes.extend(coord.item() for coord in box[:4])
        labels.append(label.item())
        confidences.append(score.item())

    return [boxes, labels, confidences]


def compute_box_distances_iou(detections, ground_truth):
    prediction_boxes_tensor = torch.FloatTensor(size=(len(detections), 4))

    for i, det in enumerate(detections):
        prediction_boxes_tensor[i, :] = torch.FloatTensor(det)

    gt_boxes_tensor = ground_truth * 1.0

    iou = 1 - box_iou(prediction_boxes_tensor, gt_boxes_tensor)

    iou[iou > 0.99] = np.nan

    return iou


def compute_box_distances(detections, ground_truth, max_distance=100):
    """
    Build a Euclidean distance matrix between box centres.

    Args:
        detections: sequence of boxes, each (x1, y1, x2, y2).
        ground_truth: tensor (or array-like) of shape (M, 4), same format.
        max_distance: centre distances strictly greater than this, in the
            boxes' coordinate units, become NaN (no match). The default of
            100 is the previously hard-coded value.

    Returns:
        numpy array of shape (len(detections), M). Rows are detections and
        columns are ground-truth boxes.
    """
    # reshape(-1, 4) keeps an empty detection list valid: shape (0, 4)
    prediction_boxes = torch.as_tensor(detections, dtype=torch.float32).reshape(-1, 4)
    gt_boxes = torch.as_tensor(ground_truth, dtype=torch.float32).reshape(-1, 4)

    # Centre = midpoint of the two corners. The corner difference would give
    # half the box size, not its position.
    pred_center = ((prediction_boxes[:, 0:2] + prediction_boxes[:, 2:4]) / 2.0).numpy()
    gt_center = ((gt_boxes[:, 0:2] + gt_boxes[:, 2:4]) / 2.0).numpy()

    center_distance = distance_matrix(pred_center, gt_center)
    center_distance[center_distance > max_distance] = np.nan

    return center_distance


def draw_box(image, track, COLORS):
    """
    Draw one track's bounding box and a "<class>-<id>" label onto ``image``.

    The box outline and a filled label strip above it are drawn in place
    with ``cv2.rectangle``. The text is then added with ``cv2.putText``, and
    the image returned by that call is returned. The colour is chosen from
    ``COLORS`` as the track id modulo the number of colours.
    """
    left, top, right, bottom = map(int, tuple(track.to_ltrb())[:4])

    track_id = track.track_id
    box_color = COLORS[int(track_id) % len(COLORS)]

    # outline of the tracked object
    cv2.rectangle(image, (left, top), (right, bottom), color=box_color)

    # filled strip (at most 60 px wide, 15 px tall) behind the label text
    strip_width = min(60, right - left)
    cv2.rectangle(image, (left, top - 15), (left + strip_width, top),
                  color=box_color, thickness=-1)

    caption = f"{track.get_det_class()}-{track_id}"
    return cv2.putText(image, caption, (left, top - 2),
                       fontFace=0, fontScale=0.45, color=(0, 0, 0))


def filter_low_confidence(detections, threshold, use_masks):
    """
    Keep only detections whose score is strictly greater than ``threshold``.

    The same boolean selection is applied to "boxes", "labels", "scores"
    and, when ``use_masks`` is True, "masks". ``detections`` is updated in
    place (each entry is replaced by its filtered tensor) and also returned.
    """
    keep = detections["scores"] > threshold

    if use_masks:
        fields = ("boxes", "masks", "labels", "scores")
    else:
        fields = ("boxes", "labels", "scores")

    for field in fields:
        detections[field] = detections[field][keep]

    return detections


def save_tracking_output(image, image_path, output_folder_name):
    """
    Save ``image`` under ``output_folder_name`` as "<stem>_bboxes<suffix>".

    The directory structure of ``image_path``'s parent is mirrored below
    ``output_folder_name`` and created if it does not exist yet.

    Args:
        image: image data, passed unchanged to ``skimage.io.imsave``.
        image_path: path (``Path`` or str) of the source image.
        output_folder_name: root directory for the saved images.
    """
    image_path = Path(image_path)

    parent = image_path.parent
    if parent.is_absolute():
        # Joining an absolute path onto output_folder_name would discard the
        # output folder and write next to the source image, so strip the
        # root / drive anchor first.
        parent = parent.relative_to(parent.anchor)

    output_folder = Path(output_folder_name) / parent
    output_folder.mkdir(parents=True, exist_ok=True)

    output_path = output_folder / f"{image_path.stem}_bboxes{image_path.suffix}"
    skio.imsave(output_path, image)


def scale_and_normalize(image):
    """
    Rescale 0-255 pixel values to [0, 1], then standardise each channel.

    Uses the standard torchvision ImageNet per-channel mean/std. These
    broadcast over the last axis, so ``image`` is expected to be
    channels-last with 3 channels.
    """
    channel_mean = np.array([0.485, 0.456, 0.406])
    channel_std = np.array([0.229, 0.224, 0.225])

    unit_range = image / 255.0
    centered = unit_range - channel_mean
    return centered / channel_std


# usage described here:
# https://pypi.org/project/deep-sort-realtime/
def track(root_dir, model, is_transformer, save_output, output_folder, colors, confidence_threshold, dev):
    """
    Run detection + DeepSort tracking over every sequence and collect metrics.

    Args:
        root_dir: passed to ``get_sequences`` to enumerate the sequences.
        model: detector. It is called directly when ``is_transformer`` is
            False; otherwise it is called with DETR inputs from ``get_detr_inputs``.
        is_transformer: selects the DETR path (no masks, DETR post-processing).
        save_output: if True, draw confirmed tracks and save every frame.
        output_folder: root folder for saved frames.
        colors: colour palette passed to ``draw_box``.
        confidence_threshold: detections at or below this score are dropped
            before tracking. mAP still sees all of them.
        dev: device argument forwarded to ``SequentialDataset`` and
            ``get_detr_inputs``.

    Returns:
        (mot_metrics, map_metrics):
        - mot_metrics: one ``mm.MOTAccumulator`` per sequence.
        - map_metrics: a one-element list holding
          ``MeanAveragePrecision.compute()`` over all sequences combined.
    """
    mot_metrics = []
    map_metrics = []

    # mAP accumulates over every sequence and is computed once at the end
    mean_average_precision = MeanAveragePrecision(
        iou_type="bbox", class_metrics=True)

    detr_processor = None
    if is_transformer:
        # from_pretrained is a classmethod; no need to build a default instance first
        detr_processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")

    # nothing inside the loops switches the model back to training mode
    model.eval()

    # only the non-transformer model produces instance masks
    use_masks = not is_transformer

    for sequence_path in get_sequences(root_dir):
        sequence_ds = SequentialDataset(
            sequence_path, normalize=False, augment=False, classes_from_zero=is_transformer, dev=dev)
        sequence_dl = DataLoader(
            sequence_ds, shuffle=False, batch_size=1, collate_fn=create_lists)

        # fresh tracker and MOT accumulator for each sequence
        ds_tracker = DeepSort(
            max_age=4, n_init=4, max_iou_distance=0.99, max_cosine_distance=0.3, bgr=False)
        acc = mm.MOTAccumulator(auto_id=True)

        for image_tensor, gt, image, image_path in sequence_dl:
            detections = _predict_single_frame(
                model, image_tensor, gt, is_transformer, detr_processor, dev)

            # drop the batch dimension of size 1; it is only needed for the model input
            image_path = image_path[0]
            image = image[0]
            gt = gt[0]

            # mAP is updated with all detections, including low-confidence ones
            mean_average_precision.update([detections], [gt])

            # tracking only sees detections above the confidence threshold
            detections = filter_low_confidence(
                detections, confidence_threshold, use_masks=use_masks)
            detections, masks = change_bounding_box_format(
                detections, use_masks=use_masks)

            # NOTE: masks are computed but not passed to the tracker
            # (an instance_masks=masks argument was previously left disabled).
            tracks = ds_tracker.update_tracks(
                detections, frame=scale_and_normalize(image))

            tracked_ids, tracked_boxes = _collect_confirmed_tracks(
                tracks, image, save_output, colors)

            if save_output:
                save_tracking_output(image, image_path, output_folder)

            gt_object_ids = gt["indices"].cpu().tolist()
            gt_boxes = gt["boxes"].cpu()

            # compute_box_distances_iou returns (hypotheses, objects), but
            # MOTAccumulator.update expects dists shaped (objects, hypotheses).
            distances = np.asarray(
                compute_box_distances_iou(tracked_boxes, gt_boxes)).T

            acc.update(oids=gt_object_ids,
                       hids=tracked_ids, dists=distances)

        mot_metrics.append(acc)

    map_metrics.append(mean_average_precision.compute())

    return mot_metrics, map_metrics


def _predict_single_frame(model, image_tensor, gt, is_transformer, detr_processor, dev):
    """Run the detector on a batch of one frame and return that frame's prediction dict."""
    with torch.no_grad():
        if not is_transformer:
            return model(image_tensor)[0]

        # target sizes for DETR post-processing; assumes images are
        # channels-first (C, H, W) so this yields (height, width). TODO confirm.
        shapes = [(im.shape[1], im.shape[2]) for im in image_tensor]
        # get_detr_inputs receives a shallow copy of the target list
        inputs = get_detr_inputs(detr_processor, image_tensor, gt.copy(), dev)
        outputs = model(**inputs)
        return detr_processor.post_process_object_detection(
            outputs, threshold=0, target_sizes=shapes)[0]


def _collect_confirmed_tracks(tracks, image, save_output, colors):
    """Return (ids, ltrb boxes) of confirmed tracks, drawing each onto ``image`` if ``save_output``."""
    tracked_ids = []
    tracked_boxes = []

    for trk in tracks:
        # unconfirmed tracks are skipped, as in the deep-sort-realtime usage example
        if not trk.is_confirmed():
            continue

        if save_output:
            draw_box(image, trk, colors)

        tracked_boxes.append([float(v) for v in trk.to_ltrb()])
        tracked_ids.append(int(trk.track_id))

    return tracked_ids, tracked_boxes
