from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks
import matplotlib.pyplot as plt
from torchvision.transforms.functional import to_pil_image
import torch


def undo_normalization(
    image,
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
):
    """Invert a per-channel ``(x - mean) / std`` normalization.

    Computes ``image * std + mean`` channel-wise. The defaults are the widely
    used ImageNet statistics.

    Args:
        image: Normalized tensor whose channel axis is third from last, e.g.
            ``(C, H, W)`` or a batch ``(N, C, H, W)``.
        mean: Per-channel means that were subtracted (one value per channel).
        std: Per-channel standard deviations that were divided by.

    Returns:
        A new tensor broadcast to the shape of ``image`` with the
        normalization removed. Values are neither clamped nor converted to
        uint8.
    """
    # Build the statistics on the image's device so CUDA inputs do not hit a
    # device-mismatch error; dtype is left to torch's type promotion, exactly
    # as before.
    mean = torch.as_tensor(mean, device=image.device)
    std = torch.as_tensor(std, device=image.device)
    # Shape (C, 1, 1) broadcasts over the spatial axes (and a leading batch
    # axis, if present).
    return image * std.reshape(-1, 1, 1) + mean.reshape(-1, 1, 1)


def plot_boxes(image, boxes):
    """Draw ``boxes`` on ``image`` and display the result in a new figure.

    Args:
        image: Image tensor in the ``(C, H, W)`` layout that torchvision's
            ``draw_bounding_boxes`` expects. Per the torchvision docs it must
            be uint8 (recent releases also accept float in ``[0, 1]``). It
            may live on any device and may require grad.
        boxes: Tensor of shape ``(num_boxes, 4)`` in absolute
            ``(xmin, ymin, xmax, ymax)`` pixel coordinates, per torchvision.

    The boxes are outlined in red with width 2 on a 20x8-inch matplotlib
    figure. ``plt.show()`` is not called, so a non-interactive caller must
    show or save the figure itself. Returns ``None``.
    """
    # Drawing happens on CPU; detach so autograd-tracked predictions work.
    image = image.detach().cpu()
    boxes = boxes.detach().cpu()
    # NOTE(review): a normalized float image must be mapped back to
    # displayable values (see undo_normalization) before it is drawn.
    annotated = draw_bounding_boxes(image, boxes, colors="red", width=2)

    _, axis = plt.subplots(figsize=(20, 8))
    axis.imshow(to_pil_image(annotated))


def plot_masks(image, masks):
    """Overlay segmentation ``masks`` on ``image`` and display it in a new figure.

    Args:
        image: Image tensor in the ``(C, H, W)`` layout that torchvision's
            ``draw_segmentation_masks`` expects. Per the torchvision docs it
            must be uint8 (recent releases also accept float in ``[0, 1]``).
            It may live on any device and may require grad.
        masks: Per the torchvision docs, a bool tensor of shape
            ``(num_masks, H, W)`` or ``(H, W)``. It may live on any device.

    The masks are filled in red with ``alpha=0.8`` on a 20x8-inch matplotlib
    figure. ``plt.show()`` is not called, so a non-interactive caller must
    show or save the figure itself. Returns ``None``.
    """
    # torchvision rejects masks that are not on the same device as the image,
    # so move both to CPU (detached, so autograd-tracked predictions work).
    image = image.detach().cpu()
    masks = masks.detach().cpu()
    # NOTE(review): a normalized float image must be mapped back to
    # displayable values (see undo_normalization) before it is drawn.
    overlaid = draw_segmentation_masks(image, masks, alpha=0.8, colors="red")

    _, axis = plt.subplots(figsize=(20, 8))
    axis.imshow(to_pil_image(overlaid))
