from training import train
from datasets import DetectionDataset, create_lists

from transformers import DetrForObjectDetection

import torch
from torch.optim import Adam
from torch.utils.data import DataLoader


# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

# Optimizer hyperparameters. The backbone parameter group uses its own, much
# smaller learning rate than the rest of the network.
LR: float = 2e-5
BACKBONE_LR: float = 1e-6
WEIGHT_DECAY: float = 1e-5

# Training loop / model settings.
EPOCHS: int = 10
BATCH_SIZE: int = 8
NUM_CLASSES: int = 2  # passed to DETR as `num_labels`
CONFIDENCE_THRESHOLD: float = 0.5  # NOTE(review): not referenced by the code below

# Checkpoint to resume from. The output path handed to train() is currently
# the very same file, so a resumed run writes back over its input checkpoint.
MODEL_PATH: str = "models/model_transformer_final"
MODEL_OUTPUT_PATH: str = MODEL_PATH


if __name__ == "__main__":
    # Resume fine-tuning of a DETR detector from a saved checkpoint.

    # Use the GPU when one is available, otherwise fall back to the CPU.
    dev = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    train_ds = DetectionDataset("training", normalize=False,
                                augment=True, classes_from_zero=True, dev=dev)
    val_ds = DetectionDataset("validation", normalize=False,
                              augment=False, classes_from_zero=True, dev=dev)

    # create_lists is the collate function for both loaders (presumably it keeps
    # per-image targets as lists rather than stacking them — confirm in datasets).
    train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                          collate_fn=create_lists)
    val_dl = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False,
                        collate_fn=create_lists)

    # Start from the pretrained facebook/detr-resnet-50 weights with a
    # classification head sized for NUM_CLASSES; ignore_mismatched_sizes lets
    # the differently shaped head weights be freshly initialised instead of
    # raising a size-mismatch error.
    model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50",
                                                   num_labels=NUM_CLASSES,
                                                   ignore_mismatched_sizes=True)
    model = model.to(dev)

    # Two parameter groups: the backbone trains with BACKBONE_LR, every other
    # parameter with LR. The backbone's fully-qualified parameter names are
    # collected once into a set so each membership test is O(1); parameter
    # order within the groups is unchanged, which keeps the optimizer state
    # dict compatible with previously saved checkpoints.
    backbone_names = {name for name, _ in
                      model.model.backbone.named_parameters(prefix="model.backbone")}
    optim = Adam([{"params": [param for name, param in model.named_parameters()
                              if name not in backbone_names],
                   "lr": LR},
                  {"params": model.model.backbone.parameters(),
                   "lr": BACKBONE_LR}],
                 lr=LR,
                 weight_decay=WEIGHT_DECAY)

    # Load the saved model together with the optimizer state and the best loss
    # reached so far. map_location places every stored tensor on `dev`, so a
    # checkpoint written from a GPU run can also be resumed on a CPU-only machine.
    # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects —
    # only load checkpoints from a trusted source.
    checkpoint = torch.load(MODEL_PATH, map_location=dev, weights_only=False)

    model.load_state_dict(checkpoint["model_state_dict"])
    # NOTE: Optimizer.load_state_dict restores the saved param-group
    # hyperparameters (lr, weight_decay), so edits to LR / BACKBONE_LR /
    # WEIGHT_DECAY above do not take effect for a resumed run.
    optim.load_state_dict(checkpoint["optimizer_state_dict"])

    best_loss = checkpoint["best_loss"]

    # Release the checkpoint's tensors before training starts.
    del checkpoint

    # Train the model. MODEL_OUTPUT_PATH is passed as output_path (presumably
    # where train() saves improved checkpoints — confirm in training.py).
    best_loss, training_losses, validation_losses = train(model, train_dl, val_dl, optim, EPOCHS,
                                                          output_path=MODEL_OUTPUT_PATH, device=dev,
                                                          best_loss=best_loss, transformer=True)