In [None]:
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import matplotlib.patches as patches

import albumentations
from albumentations import Compose, Resize, BboxParams, Crop
from albumentations.pytorch import ToTensorV2

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, random_split, DataLoader
from skimage import io
from skimage.draw import rectangle_perimeter

from transformers import DetrImageProcessor, DetrForObjectDetection

from pathlib import Path

from tqdm import tqdm

from torchvision.utils import draw_bounding_boxes
from torchvision.io import read_image

import os

In [None]:
# Load the bounding-box annotation table (one row per labelled object).
labels = pd.read_csv('samples.csv')

# '1395.tif' is referenced by the annotations but is not included in the data,
# so its rows are dropped before building the dataset.
labels = labels.loc[labels['image_id'] != '1395.tif']

In [None]:
# Label maps.
# m_labels:   class index -> class name, in order of first appearance in the table
# m_labels_r: class name  -> class index (the reverse lookup used by the dataset)
# Built with enumerate so no counter / loop variable is left behind in the
# notebook namespace.
m_labels = dict(enumerate(labels.type_name.unique()))
m_labels_r = {name: index for index, name in m_labels.items()}

In [None]:
# Image processor with its default settings, plus a DETR model initialised from
# the public ResNet-50 checkpoint. The classification head is sized to our label
# set and the number of object queries is set to 250; ignore_mismatched_sizes
# is passed so loading does not fail on weights whose shapes differ from the
# checkpoint's.
processor = DetrImageProcessor()
model = DetrForObjectDetection.from_pretrained(
    "facebook/detr-resnet-50",
    num_labels=len(m_labels_r),
    num_queries=250,
    ignore_mismatched_sizes=True,
)

In [None]:
class XviewDS(Dataset):
    """Dataset that tiles every image into a grid of patches for DETR training.

    One dataset item corresponds to one source image. The image is cut into a
    ``grid_size`` x ``grid_size`` grid, and each patch is run through the image
    processor together with the boxes that lie fully inside it.

    Args:
        df: Annotation table with columns ``image_id``, ``minc``, ``minr``,
            ``maxc``, ``maxr`` (box corners in pixels, column/row order) and
            ``type_name``.
        labelmap: Mapping from ``type_name`` to integer class index.
        processor: Callable used as
            ``processor(images=..., annotations=..., return_tensors="pt")``
            returning ``"pixel_values"`` and ``"labels"``.
        image_dir: Directory containing the image files.
        grid_size: Number of patches along each image axis.
        min_box_size: Boxes narrower or shorter than this many pixels are dropped.
    """

    # Per-box entries of the processed target that must stay aligned with 'boxes'.
    _PER_BOX_KEYS = ('boxes', 'class_labels', 'area', 'iscrowd')

    def __init__(self, df, labelmap, processor, image_dir="train_images",
                 grid_size=3, min_box_size=8):
        self.df = df
        self.labelmap = labelmap
        self.processor = processor
        self.image_dir = image_dir
        self.grid_size = grid_size
        self.min_box_size = min_box_size
        self.image_names = self.df['image_id'].unique()

    def __len__(self):
        """Number of distinct source images."""
        return len(self.image_names)

    @staticmethod
    def _grid_edges(length, parts):
        """Return the cut positions that split ``length`` pixels into patches."""
        step = max(1, length // parts)  # never 0, so range() cannot raise
        return list(range(0, length + 1, step))

    def _patch_annotations(self, boxes_df, xa, ya, xb, yb):
        """Build COCO-style annotations for boxes fully inside one patch.

        Box coordinates are shifted so they are relative to the patch origin.
        """
        inside = boxes_df[
            (boxes_df.minc >= xa) & (boxes_df.maxc <= xb)
            & (boxes_df.minr >= ya) & (boxes_df.maxr <= yb)
        ]
        columns = ['minc', 'minr', 'maxc', 'maxr', 'type_name']

        annotations = []
        for minc, minr, maxc, maxr, type_name in inside[columns].itertuples(index=False):
            w = maxc - minc
            h = maxr - minr
            # keep only boxes of at least some size (also rejects inverted boxes)
            if w < self.min_box_size or h < self.min_box_size:
                continue
            annotations.append({
                'bbox': np.array([minc - xa, minr - ya, w, h]),
                'category_id': self.labelmap[type_name],
                'area': w * h,
            })
        return annotations

    def __getitem__(self, index):
        """Return a list of ``(pixel_values, target)`` pairs, one per patch."""
        img_id = self.image_names[index]
        img = io.imread(f"{self.image_dir}/{img_id}")

        # only the first two axes are needed, so grayscale images work too
        he, wi = img.shape[:2]
        df = self.df[self.df['image_id'] == img_id]
        image_id = int(img_id[:img_id.index('.')])

        h_edges = self._grid_edges(he, self.grid_size)
        w_edges = self._grid_edges(wi, self.grid_size)

        out = []
        # row-major order: all columns of the first row, then the next row, ...
        for ya, yb in zip(h_edges, h_edges[1:]):
            for xa, xb in zip(w_edges, w_edges[1:]):
                # plain slicing yields the same pixels as a Crop(xa, ya, xb, yb)
                patch = np.ascontiguousarray(img[ya:yb, xa:xb])

                # a fresh annotation list per patch, so patches do not inherit
                # the boxes of the patches processed before them
                target = {
                    'image_id': image_id,
                    'annotations': self._patch_annotations(df, xa, ya, xb, yb),
                }
                encoding = self.processor(images=patch, annotations=target, return_tensors="pt")
                pixel_values = encoding["pixel_values"].squeeze(0)
                target = encoding["labels"][0]

                # Remove boxes containing NaN values (they began to pop up during
                # training). Every per-box tensor is filtered with the same mask
                # so boxes and class labels stay aligned.
                boxes = target['boxes']
                n_boxes = boxes.shape[0]
                valid = ~torch.any(boxes.isnan(), dim=1)
                for key in self._PER_BOX_KEYS:
                    if key not in target:
                        continue
                    value = target[key]
                    if torch.is_tensor(value) and value.dim() > 0 and value.shape[0] == n_boxes:
                        target[key] = value[valid]

                out.append((pixel_values, target))
        return out

In [None]:
ds = XviewDS(labels, m_labels_r, processor)

# 80/20 train/validation split. A seeded generator keeps the split identical
# across kernel restarts, so validation images never leak into training on a
# later run.
split_rng = torch.Generator().manual_seed(42)
train_ds, val_ds = random_split(ds, [0.8, 0.2], generator=split_rng)

In [None]:
from torch.utils.data import DataLoader

def collate_fn(batch):
    """Collate dataset items, padding the patches of each item separately.

    Each element of ``batch`` is a list of ``(pixel_values, target)`` pairs.
    For every element the pixel tensors are padded together with the
    module-level ``processor`` and returned as one dict with the keys
    ``'pixel_values'``, ``'pixel_mask'`` and ``'labels'``.

    Returns:
        A list with one such dict per element of ``batch``.
    """
    collated = []
    for patch_list in batch:
        images = [pair[0] for pair in patch_list]
        targets = [pair[1] for pair in patch_list]
        padded = processor.pad(images, return_tensors="pt")
        collated.append({
            'pixel_values': padded['pixel_values'],
            'pixel_mask': padded['pixel_mask'],
            'labels': targets,
        })
    return collated

# Training loader: one source image (i.e. one list of patches) per step,
# reshuffled every epoch.
train_dataloader = DataLoader(
    dataset=train_ds,
    batch_size=1,
    shuffle=True,
    collate_fn=collate_fn,
)

# Validation loader: two source images per step, in dataset order.
val_dl = DataLoader(
    dataset=val_ds,
    batch_size=2,
    collate_fn=collate_fn,
)

In [None]:
# model.load_state_dict(torch.load('./vitmodel.pt'))  # optional: resume from a previously saved model

In [None]:
def train(model, opt, dl):
    """Run one training epoch and return the mean loss.

    Args:
        model: Module called as ``model(pixel_values=..., pixel_mask=...,
            labels=...)`` whose output exposes a ``loss`` attribute.
        opt: Optimizer over the model's parameters.
        dl: Iterable of batches; each batch is a list of dicts with the keys
            ``'pixel_values'``, ``'pixel_mask'`` and ``'labels'`` (a list of
            dicts of tensors).

    Returns:
        The mean loss over all optimisation steps as a float (0.0 when ``dl``
        yields no steps).
    """
    model.train()

    # Use whatever device the model already lives on, so inputs and targets
    # always end up next to the weights.
    dev = next(model.parameters()).device

    total_loss = 0.0
    n_steps = 0
    for data in tqdm(dl, position=0, leave=True):
        for d in data:
            pixel_values = d['pixel_values'].to(dev)
            pixel_mask = d["pixel_mask"].to(dev)
            labels = [{k: v.to(dev) for k, v in t.items()} for t in d["labels"]]

            # clear gradients from the previous step; otherwise they accumulate
            opt.zero_grad()
            out = model(pixel_values=pixel_values, pixel_mask=pixel_mask, labels=labels)
            loss = out.loss
            loss.backward()
            opt.step()

            step_loss = loss.item()
            total_loss += step_loss
            n_steps += 1
            print(step_loss)

    return total_loss / n_steps if n_steps else 0.0

In [None]:
# AdamW over all model parameters with a small learning rate for fine-tuning.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-5)

In [None]:
epochs = 5

# Training: one full pass over the training loader per epoch.
# This takes a long time because the model is a transformer.
for epoch in tqdm(range(epochs), position=0, leave=True):
    epoch_label = epoch + 1
    print(f"Epoch {epoch_label}")
    train(model, optimizer, train_dataloader)

In [None]:
def print_bboxes(image, boxes):
    """Show ``image`` with a red outline drawn for every box.

    The outlines are drawn on a copy, so the caller's array is left untouched
    and can be reused for another plot.

    Args:
        image: Image array indexed as ``[row, column, channel]`` with three
            channels (the outline colour is an RGB triple).
        boxes: Iterable of ``(x_min, y_min, x_max, y_max)`` in pixels. Values
            may be ints, floats or 0-d tensors; they are rounded to ints.
    """
    canvas = np.array(image, copy=True)
    height, width = canvas.shape[:2]

    for bbox in boxes:
        x_min, y_min, x_max, y_max = (int(round(float(v))) for v in bbox)
        # keep the far corner on a valid pixel index
        x_max = min(x_max, width - 1)
        y_max = min(y_max, height - 1)
        rr, cc = rectangle_perimeter(start=(y_min, x_min), end=(y_max, x_max), shape=canvas.shape, clip=True)
        canvas[rr, cc] = (255, 0, 0)

    plt.imshow(canvas)
    plt.show()

In [None]:
def test_image(image_id, df, print_results=False, threshold=0.3):
    """Plot the ground-truth boxes and the predicted boxes for one image.

    Two figures are shown: first the image with its annotated boxes, then the
    same image with the boxes predicted by the module-level ``model``.

    Args:
        image_id: Numeric id; the file ``train_images/{image_id}.tif`` is read.
        df: Annotation table with columns ``image_id``, ``minc``, ``minr``,
            ``maxc`` and ``maxr``.
        print_results: When true, print the post-processed detections,
            including the confidence scores.
        threshold: Minimum confidence for a prediction to be kept.
    """
    filename = f'train_images/{image_id}.tif'
    test_img = io.imread(filename)

    device = next(model.parameters()).device
    was_training = model.training
    model.eval()  # inference must not run with dropout active
    try:
        with torch.no_grad():
            inputs = processor(images=test_img, return_tensors='pt')
            inputs = {k: v.to(device) for k, v in inputs.items()}
            outputs = model(**inputs)
    finally:
        # leave the model in the mode the caller had it in
        model.train(was_training)

    target_sizes = torch.tensor([test_img.shape[:2]])
    results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=threshold)[0]

    if print_results:
        print(results)  # outputting the evaluation results, including the confidence scores

    img_labels = df[df.image_id == f'{image_id}.tif']

    tboxes = img_labels[['minc', 'minr', 'maxc', 'maxr']].apply(tuple, axis=1).tolist()
    # only boxes whose four coordinates are all strictly positive are drawn
    tboxes = [(a, b, c, d) for (a, b, c, d) in tboxes if a > 0 and b > 0 and c > 0 and d > 0]

    # Each plot gets its own copy of the image, so the prediction plot does not
    # also contain the ground-truth outlines.
    print_bboxes(test_img.copy(), tboxes)

    # whole-pixel integer coordinates for drawing
    pboxes = torch.round(results['boxes']).to(torch.int64).cpu().tolist()
    print_bboxes(test_img.copy(), pboxes)



In [None]:
# example usage
test_image(image_id=20, df=labels)  # shows the image that is included in the lab report

In [None]:
# UNUSED METHODS
# methods that ended up being unused
# but could be used to calculate IoU scores, accuracy and precision

In [None]:
import torch

def calculate_iou(box_a, box_b):
    """Pairwise intersection-over-union of two sets of corner-format boxes.

    Args:
        box_a: Tensor of shape (N, >=4); the first four columns are
            ``(x_min, y_min, x_max, y_max)``. Extra columns (e.g. a score)
            are ignored.
        box_b: Tensor of shape (M, >=4) in the same format.

    Returns:
        Tensor of shape (N, M) where entry ``[i, j]`` is the IoU of
        ``box_a[i]`` and ``box_b[j]``. Pairs whose union has no area get 0
        instead of NaN.
    """
    corners_a = box_a[:, :4]
    corners_b = box_b[:, :4]

    # intersection rectangle: smallest far corner, largest near corner
    max_xy = torch.min(corners_a[:, 2:].unsqueeze(1), corners_b[:, 2:].unsqueeze(0))
    min_xy = torch.max(corners_a[:, :2].unsqueeze(1), corners_b[:, :2].unsqueeze(0))
    inter = torch.clamp(max_xy - min_xy, min=0)  # no overlap -> zero extent
    inter_area = inter[:, :, 0] * inter[:, :, 1]

    area_a = ((corners_a[:, 2] - corners_a[:, 0]) * (corners_a[:, 3] - corners_a[:, 1])).unsqueeze(1)
    area_b = ((corners_b[:, 2] - corners_b[:, 0]) * (corners_b[:, 3] - corners_b[:, 1])).unsqueeze(0)
    union_area = area_a + area_b - inter_area

    iou = inter_area / union_area
    # degenerate (zero-area) pairs would otherwise produce 0/0 = NaN
    return torch.where(union_area > 0, iou, torch.zeros_like(iou))


def calculate_average_precision(pred_boxes, true_boxes, iou_threshold=0.1):
    """11-point interpolated average precision for one set of detections.

    Detections are matched greedily to ground-truth boxes: each detection takes
    the not-yet-matched ground-truth box it overlaps most, and counts as a true
    positive when that IoU exceeds ``iou_threshold``.

    Args:
        pred_boxes: Tensor of shape (N, 4) or (N, >=5). The first four columns
            are ``(x_min, y_min, x_max, y_max)``. With five or more columns the
            last column is treated as the confidence score and detections are
            ranked by it, highest first. With exactly four columns there is no
            score to rank by, so the given order is used; pass the boxes
            already sorted by confidence.
        true_boxes: Tensor of shape (M, >=4) with ground-truth corner boxes.
        iou_threshold: IoU a match must exceed to count as a true positive.

    Returns:
        The average precision as a float; 0.0 when either set is empty.
    """
    num_pred_boxes = len(pred_boxes)
    num_true_boxes = len(true_boxes)

    if num_pred_boxes == 0 or num_true_boxes == 0:
        return 0.0

    if pred_boxes.shape[1] > 4:
        # rank by confidence, most confident first
        pred_boxes = pred_boxes[pred_boxes[:, -1].argsort(dim=0, descending=True)]

    iou = calculate_iou(pred_boxes[:, :4], true_boxes[:, :4])

    true_positives = torch.zeros(num_pred_boxes)
    true_box_flags = torch.zeros(num_true_boxes)

    for i in range(num_pred_boxes):
        best_iou = 0.0
        best_match_idx = -1

        for j in range(num_true_boxes):
            if iou[i, j] > best_iou and not true_box_flags[j]:
                best_iou = iou[i, j]
                best_match_idx = j

        if best_iou > iou_threshold:
            true_positives[i] = 1.0
            true_box_flags[best_match_idx] = 1.0  # each true box is matched once

    # every detection is either a true positive or a false positive
    false_positives = 1.0 - true_positives

    tp_cum = torch.cumsum(true_positives, dim=0)
    fp_cum = torch.cumsum(false_positives, dim=0)

    precision = tp_cum / (tp_cum + fp_cum)
    recall = tp_cum / num_true_boxes

    num_recall_points = 11
    recall_thresholds = torch.linspace(0, 1, num_recall_points)

    average_precision = 0.0
    for recall_threshold in recall_thresholds:
        recall_mask = recall >= recall_threshold
        if recall_mask.any():
            # best precision achievable at this recall level or higher
            average_precision += precision[recall_mask].max().item()

    return average_precision / num_recall_points


def calculate_accuracy(pred_boxes, true_boxes, iou_threshold=0.1):
    """Fraction of ground-truth boxes that are hit by a prediction.

    Predictions are processed in the given order. Each one is paired with the
    still-unclaimed ground-truth box it overlaps most and counts as a hit when
    that IoU is greater than ``iou_threshold``; a ground-truth box can be
    claimed only once.

    Returns:
        Number of hits divided by the number of ground-truth boxes, as a
        float; 0.0 when either input is empty.
    """
    n_pred = len(pred_boxes)
    n_true = len(true_boxes)
    if not n_pred or not n_true:
        return 0.0

    overlaps = calculate_iou(pred_boxes, true_boxes)

    hits = torch.zeros(n_pred)
    claimed = torch.zeros(n_true)

    for p in range(n_pred):
        top_iou, top_idx = 0.0, -1
        for t in range(n_true):
            candidate = overlaps[p, t]
            # same test as before, split in two: must beat the current best
            # and the ground-truth box must still be free
            if candidate > top_iou:
                if not claimed[t]:
                    top_iou, top_idx = candidate, t

        if top_iou > iou_threshold:
            hits[p] = 1.0
            claimed[top_idx] = 1.0

    return (hits.sum() / n_true).item()

In [None]:
# Validation metrics: mean average precision and mean "accuracy" (share of
# ground-truth boxes that were found), averaged over every validation patch
# that contains at least one ground-truth box.

SCORE_THRESHOLD = 0.3  # minimum class confidence for a query to count as a detection
IOU_THRESHOLD = 0.1    # overlap a detection needs to match a ground-truth box


def _cxcywh_to_xyxy(boxes):
    """Convert (center_x, center_y, width, height) boxes to corner format."""
    cx, cy, w, h = boxes.unbind(-1)
    return torch.stack([cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2], dim=-1)


eval_device = next(model.parameters()).device
model.eval()  # evaluation must not run with dropout active

# accumulators live outside the loops so they cover the whole validation set
ap_sum = 0.0
acc_sum = 0.0
n_scored = 0

for val_batch in tqdm(val_dl):
    for image_patches in val_batch:
        with torch.no_grad():
            out = model(
                pixel_values=image_patches['pixel_values'].to(eval_device),
                pixel_mask=image_patches['pixel_mask'].to(eval_device),
            )

        # The last logit is the "no object" class; the confidence of a query is
        # its best probability among the real classes.
        query_scores = out.logits.softmax(-1)[..., :-1].max(-1).values.cpu()
        # The model predicts normalised (cx, cy, w, h) boxes and the processed
        # targets use the same format; the IoU helpers expect corner format.
        query_boxes = _cxcywh_to_xyxy(out.pred_boxes.cpu())

        # score every patch on its own so boxes of different patches never mix
        for patch_scores, patch_boxes, target in zip(query_scores, query_boxes, image_patches['labels']):
            true_boxes = _cxcywh_to_xyxy(target['boxes'].cpu())
            if len(true_boxes) == 0:
                continue  # the metrics are undefined without ground truth

            keep = patch_scores > SCORE_THRESHOLD
            ranking = patch_scores[keep].argsort(descending=True)
            detections = patch_boxes[keep][ranking]  # most confident first

            ap_sum += calculate_average_precision(detections, true_boxes, iou_threshold=IOU_THRESHOLD)
            acc_sum += calculate_accuracy(detections, true_boxes, iou_threshold=IOU_THRESHOLD)
            n_scored += 1

if n_scored:
    print(ap_sum / n_scored)   # average precision
    print(acc_sum / n_scored)  # average accuracy
else:
    print("no validation patches with ground-truth boxes")
