from tqdm.notebook import tqdm
import torch
from transformers import DetrImageProcessor


def annotation_to_detr_format(image, gt):
    """
    Convert one ground-truth dict from the torchvision MaskRCNN format
    (corner boxes ``[x0, y0, x1, y1]``) to the COCO-style dict expected
    by ``DetrImageProcessor``.

    Args:
        image: image tensor, assumed (C, H, W) — kept for interface
            compatibility; no longer used now that "area" is per-box.
        gt: dict with "boxes" (N, 4 corner-format tensor), "labels" (N,)
            and "image_id".

    Returns:
        dict with "image_id" and "annotations", a list of COCO-style
        dicts each holding an xywh "bbox", "category_id" and "area".
    """
    annots = []

    for i, bbox in enumerate(gt["boxes"]):

        # .clone() is essential: .cpu() returns the SAME tensor when the
        # input already lives on CPU, so the in-place edits below would
        # otherwise corrupt gt["boxes"] for the caller.
        xywh_box = bbox.cpu().clone()
        xywh_box[2] = xywh_box[2] - xywh_box[0]  # width  = x1 - x0
        xywh_box[3] = xywh_box[3] - xywh_box[1]  # height = y1 - y0

        annots.append({
            "bbox": xywh_box,
            "category_id": gt["labels"][i].cpu(),
            # COCO "area" is the area of the object itself; the previous
            # code used the full image area (H * W) for every box.
            "area": xywh_box[2] * xywh_box[3],
        })

    return {
        "image_id": gt["image_id"],
        "annotations": annots,
    }


def batch_to_detr_format(images, gts):
    """
    Convert every annotation of a batch to the DETR/COCO format.

    Note: ``gts`` is rewritten in place, element by element; the
    untouched ``images`` and the converted list are returned together
    for caller convenience.
    """
    for idx in range(len(gts)):
        gts[idx] = annotation_to_detr_format(images[idx], gts[idx])

    return images, gts


def get_detr_inputs(detr_processor, im, gt, device):
    """
    Build the keyword inputs for a DETR forward pass from a raw batch.

    The batch is first converted to COCO-style annotations, then pushed
    through the HuggingFace processor. The label dicts are moved to
    ``device`` one tensor at a time because the processor output's
    ``.to(device)`` does not descend into the list of label dicts.
    """
    images, annotations = batch_to_detr_format(im, gt)

    inputs = detr_processor.preprocess(
        images=images,
        annotations=annotations,
        do_rescale=False,
        do_normalize=True,
        do_resize=True,
        return_tensors="pt",
    ).to(device)

    inputs["labels"] = [
        {key: val.to(device) for key, val in label.items()}
        for label in inputs["labels"]
    ]

    return inputs


def train(model, train_dl, val_dl, optim, epochs, output_path, device,
          best_loss=float("inf"), transformer=False):
    """
    Train ``model`` for ``epochs`` epochs, validating after each epoch
    and checkpointing whenever the validation loss improves.

    Args:
        model: either a torchvision-style detection model called as
            ``model(images, targets)`` returning a dict of losses, or a
            HuggingFace DETR model when ``transformer=True``.
        train_dl / val_dl: dataloaders yielding ``(images, targets)`` batches.
        optim: optimizer over ``model``'s parameters.
        epochs: number of epochs to run.
        output_path: path the best checkpoint is saved to.
        device: device DETR inputs are moved to (transformer path only).
        best_loss: best validation loss seen so far — pass it when resuming
            from a checkpoint. Defaults to +inf so the first validation
            pass always checkpoints; the old default of 0 meant
            ``avg_loss < best_loss`` was never true and nothing was saved.
        transformer: True to use the DETR pre-processing/forward path.

    Returns:
        (best_loss, training_losses, validation_losses) — the per-epoch
        mean losses as plain Python floats.
    """
    detr_processor = None

    training_losses = []
    validation_losses = []

    if transformer:
        # from_pretrained is a classmethod — no throwaway instance needed.
        detr_processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")

    for epoch in tqdm(range(epochs), desc="Epoch:", position=0):
        model.train()

        training_loss = 0.0
        n_train = 0

        for im, gt in tqdm(train_dl, desc="Batch:", position=1, leave=False):

            optim.zero_grad()

            if transformer:
                inputs = get_detr_inputs(detr_processor, im, gt, device)
                pred = model(**inputs)
                loss = pred.loss
            else:
                pred = model(im, gt)
                loss = sum(l for l in pred.values())

            loss.backward()
            optim.step()

            # .item() detaches the scalar; accumulating the tensor itself
            # would keep every batch's autograd graph alive for the epoch.
            training_loss += loss.item()
            n_train += 1

        # max(., 1) guards an empty dataloader; the previous counter
        # started at 1, so sums were divided by n_batches + 1.
        mean_train_loss = training_loss / max(n_train, 1)
        print(f"Training loss: {mean_train_loss}")
        training_losses.append(mean_train_loss)

        # NOTE(review): model intentionally stays in train mode here —
        # torchvision detection models only return loss dicts in train
        # mode, so calling model.eval() would break the non-transformer
        # path. Gradients are still disabled below.
        with torch.no_grad():

            val_loss = 0.0
            n_val = 0

            for im, gt in tqdm(val_dl, desc="Batch:", position=1):
                if transformer:
                    inputs = get_detr_inputs(detr_processor, im, gt, device)
                    pred = model(**inputs)
                    loss = pred.loss
                else:
                    pred = model(im, gt)
                    loss = sum(l for l in pred.values())

                val_loss += loss.item()
                n_val += 1

            avg_loss = val_loss / max(n_val, 1)

            validation_losses.append(avg_loss)

            print(f"Average loss on validation data: {avg_loss}")

            if avg_loss < best_loss:
                best_loss = avg_loss

                print(f"Saving new best model with loss: {best_loss}")

                torch.save({
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optim.state_dict(),
                    "best_loss": best_loss,
                }, output_path)

    return best_loss, training_losses, validation_losses
