from training import train
from datasets import DetectionDataset, create_lists

from torchvision.models.detection import (maskrcnn_resnet50_fpn_v2, MaskRCNN_ResNet50_FPN_V2_Weights)
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

import torch
from torch.optim import Adam
from torch.utils.data import DataLoader


# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

# Optimisation: the RPN and ROI heads use LR, the pretrained backbone uses the
# smaller BACKBONE_LR; WEIGHT_DECAY applies to every parameter group.
LR: float = 1e-5
BACKBONE_LR: float = 1e-6
WEIGHT_DECAY: float = 1e-5

# Training schedule.
EPOCHS: int = 10
BATCH_SIZE: int = 2

# Object classes; the prediction heads are built with NUM_CLASSES + 1 outputs
# (presumably the extra slot is torchvision's background class).
NUM_CLASSES: int = 2

# NOTE(review): not referenced in this script — presumably a detection score
# cutoff used at inference time; confirm or remove.
CONFIDENCE_THRESHOLD: float = 0.5

# Checkpoint to resume from, and the path handed to train() as output_path.
# Both point to the same file, so a run presumably overwrites the checkpoint
# it resumed from.
MODEL_PATH: str = "models/model_maskrcnn_final"
MODEL_OUTPUT_PATH: str = "models/model_maskrcnn_final"


if __name__ == "__main__":
    # Train on the GPU when one is available, otherwise fall back to the CPU.
    dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Augmentation is enabled only for the training split; the validation
    # loader is not shuffled so its order is identical across epochs.
    train_ds = DetectionDataset("training", normalize=False,
                                augment=True, classes_from_zero=False, dev=dev)
    val_ds = DetectionDataset("validation", normalize=False,
                              augment=False, classes_from_zero=False, dev=dev)

    train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                          collate_fn=create_lists)
    val_dl = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False,
                        collate_fn=create_lists)

    # Load the pretrained Mask R-CNN v2 and replace both prediction heads so
    # they emit NUM_CLASSES + 1 classes (torchvision's num_classes counts the
    # background class).
    model = maskrcnn_resnet50_fpn_v2(
        weights=MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT)

    # Take the head input sizes from the existing predictors instead of
    # hard-coding them, so they always match the loaded architecture.
    box_in_features = model.roi_heads.box_predictor.cls_score.in_features
    mask_in_channels = model.roi_heads.mask_predictor.conv5_mask.in_channels

    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_channels=mask_in_channels,
                                                       dim_reduced=256,
                                                       num_classes=NUM_CLASSES + 1)
    model.roi_heads.box_predictor = FastRCNNPredictor(in_channels=box_in_features,
                                                      num_classes=NUM_CLASSES + 1)

    model = model.to(dev)

    # The pretrained backbone is updated with a smaller learning rate than the
    # freshly initialised RPN / ROI heads.
    optim = Adam([{"params": model.backbone.parameters(), "lr": BACKBONE_LR},
                  {"params": model.rpn.parameters(), "lr": LR},
                  {"params": model.roi_heads.parameters(), "lr": LR}],
                 lr=LR, weight_decay=WEIGHT_DECAY)

    # Resume from a previous run. map_location lets a checkpoint that was saved
    # on a GPU be restored when `dev` fell back to the CPU.
    # NOTE(review): weights_only=False unpickles arbitrary objects — only load
    # checkpoints from a trusted source.
    checkpoint = torch.load(MODEL_PATH, map_location=dev, weights_only=False)

    model.load_state_dict(checkpoint["model_state_dict"])
    # This also restores the param-group hyperparameters stored in the
    # checkpoint, so LR / BACKBONE_LR / WEIGHT_DECAY above are overridden by
    # the saved values when resuming.
    optim.load_state_dict(checkpoint["optimizer_state_dict"])
    best_loss = checkpoint["best_loss"]

    # Release the checkpoint's tensors before training starts.
    del checkpoint

    # NOTE(review): MODEL_OUTPUT_PATH currently equals MODEL_PATH, so train()
    # presumably overwrites the checkpoint resumed from — confirm intended.
    best_loss, training_losses, validation_losses = train(model, train_dl, val_dl, optim, EPOCHS,
                                                          output_path=MODEL_OUTPUT_PATH, device=dev,
                                                          best_loss=best_loss)