diff --git a/gallery/plot_visualization_utils.py b/gallery/plot_visualization_utils.py
index 58788437a28..35c8936b457 100644
--- a/gallery/plot_visualization_utils.py
+++ b/gallery/plot_visualization_utils.py
@@ -24,7 +24,8 @@ def show(imgs):
         imgs = [imgs]
     fix, axs = plt.subplots(ncols=len(imgs), squeeze=False)
     for i, img in enumerate(imgs):
-        img = F.to_pil_image(img.to('cpu'))
+        img = img.detach()
+        img = F.to_pil_image(img)
         axs[0, i].imshow(np.asarray(img))
         axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
@@ -38,10 +39,14 @@ def show(imgs):
 from torchvision.utils import make_grid
 from torchvision.io import read_image
+from torchvision.transforms.functional import convert_image_dtype
 from pathlib import Path
+
 dog1_int = read_image(str(Path('assets') / 'dog1.jpg'))
+dog1 = convert_image_dtype(dog1_int, dtype=torch.float)
 dog2_int = read_image(str(Path('assets') / 'dog2.jpg'))
+dog2 = convert_image_dtype(dog2_int, dtype=torch.float)
 grid = make_grid([dog1_int, dog2_int, dog1_int, dog2_int])
 show(grid)
@@ -50,16 +55,15 @@ def show(imgs):
 # Visualizing bounding boxes
 # --------------------------
 # We can use :func:`~torchvision.utils.draw_bounding_boxes` to draw boxes on an
-# image. We can set the colors, labels, width as well as font and font size !
-# The boxes are in ``(xmin, ymin, xmax, ymax)`` format
-# from torchvision.utils import draw_bounding_boxes
+# image. We can set the colors, labels, width as well as font and font size.
+# The boxes are in ``(xmin, ymin, xmax, ymax)`` format.
 from torchvision.utils import draw_bounding_boxes

 boxes = torch.tensor([[50, 50, 100, 200], [210, 150, 350, 430]], dtype=torch.float)
 colors = ["blue", "yellow"]
-result = draw_bounding_boxes(dog1_int, boxes, colors=colors, width=5)
+result = draw_bounding_boxes(dog1, boxes, colors=colors, width=5)
 show(result)
@@ -71,12 +75,9 @@ def show(imgs):
 # :func:`~torchvision.models.detection.retinanet_resnet50_fpn`.

 from torchvision.models.detection import fasterrcnn_resnet50_fpn
-from torchvision.transforms.functional import convert_image_dtype

-dog1_float = convert_image_dtype(dog1_int, dtype=torch.float)
-dog2_float = convert_image_dtype(dog2_int, dtype=torch.float)
-batch = torch.stack([dog1_float, dog2_float])
+batch = torch.stack([dog1, dog2])

 model = fasterrcnn_resnet50_fpn(pretrained=True, progress=False)
 model = model.eval()
@@ -90,8 +91,8 @@ def show(imgs):
 threshold = .8
 dogs_with_boxes = [
-    draw_bounding_boxes(dog_int, boxes=output['boxes'][output['scores'] > threshold], width=4)
-    for dog_int, output in zip((dog1_int, dog2_int), outputs)
+    draw_bounding_boxes(dog, boxes=output['boxes'][output['scores'] > threshold], width=4)
+    for dog, output in zip((dog1, dog2), outputs)
 ]
 show(dogs_with_boxes)
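+
+#####################################
+# As a quick, purely illustrative check (nothing below depends on it), we can
+# print the scores of each detection to see what the ``.8`` threshold keeps:
+
+for i, output in enumerate(outputs):
+    print(f"Image {i} scores: {output['scores'].tolist()}")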
@@ -99,33 +100,256 @@
 # Visualizing segmentation masks
 # ------------------------------
 # The :func:`~torchvision.utils.draw_segmentation_masks` function can be used to
-# draw segmentation amasks on images. We can set the colors as well as
-# transparency of masks.
+# draw segmentation masks on images. Semantic segmentation and instance
+# segmentation models have different outputs, so we will treat each
+# independently.
 #
-# Here is demo with torchvision's FCN Resnet-50, loaded with
-# :func:`~torchvision.models.segmentation.fcn_resnet50`.
-# You can also try using
-# DeepLabv3 (:func:`~torchvision.models.segmentation.deeplabv3_resnet50`)
-# or lraspp mobilenet models
+# Semantic segmentation models
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+#
+# We will see how to use it with torchvision's FCN Resnet-50, loaded with
+# :func:`~torchvision.models.segmentation.fcn_resnet50`. You can also try using
+# DeepLabv3 (:func:`~torchvision.models.segmentation.deeplabv3_resnet50`) or
+# lraspp mobilenet models
 # (:func:`~torchvision.models.segmentation.lraspp_mobilenet_v3_large`).
 #
-# Like :func:`~torchvision.utils.draw_bounding_boxes`,
-# :func:`~torchvision.utils.draw_segmentation_masks` requires a single RGB image
-# of dtype `uint8`.
+# Let's start by looking at the output of the model. Remember that in general,
+# images must be normalized before they're passed to a semantic segmentation
+# model.

 from torchvision.models.segmentation import fcn_resnet50
-from torchvision.utils import draw_segmentation_masks

 model = fcn_resnet50(pretrained=True, progress=False)
 model = model.eval()

-# The model expects the batch to be normalized
-batch = F.normalize(batch, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
-outputs = model(batch)
+normalized_batch = F.normalize(batch, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
+output = model(normalized_batch)['out']
+print(output.shape, output.min().item(), output.max().item())
+
+#####################################
+# As we can see above, the output of the segmentation model is a tensor of shape
+# ``(batch_size, num_classes, H, W)``. Each value is a non-normalized score, and
+# we can normalize them into ``[0, 1]`` by using a softmax. After the softmax,
+# we can interpret each value as a probability indicating how likely a given
+# pixel is to belong to a given class.
+#
+# Let's plot the masks that have been detected for the dog class and for the
+# boat class:
+
+sem_classes = [
+    '__background__', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
+    'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
+    'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'
+]
+sem_class_to_idx = {cls: idx for (idx, cls) in enumerate(sem_classes)}
+
+# We normalize the masks of each image in the batch independently by taking a
+# softmax over the class dimension
+normalized_masks = torch.nn.functional.softmax(output, dim=1)
+
+dog_and_boat_masks = [
+    normalized_masks[img_idx, sem_class_to_idx[cls]]
+    for img_idx in range(batch.shape[0])
+    for cls in ('dog', 'boat')
+]
+
+show(dog_and_boat_masks)
+
+#####################################
+# As expected, the model is confident about the dog class, but not so much for
+# the boat class.
+#
+# The :func:`~torchvision.utils.draw_segmentation_masks` function can be used to
+# plot those masks on top of the original image. This function expects the
+# masks to be boolean masks, but our masks above contain probabilities in ``[0,
+# 1]``. To get boolean masks, we can do the following:
+
+class_dim = 1
+boolean_dog_masks = (normalized_masks.argmax(class_dim) == sem_class_to_idx['dog'])
+print(f"shape = {boolean_dog_masks.shape}, dtype = {boolean_dog_masks.dtype}")
+show([m.float() for m in boolean_dog_masks])
+
+
+#####################################
+# The line above where we define ``boolean_dog_masks`` is a bit cryptic, but you
+# can read it as the following query: "For which pixels is 'dog' the most likely
+# class?"
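+
+#####################################
+# As a small illustrative aside (nothing below depends on it), boolean masks
+# make it easy to compute simple statistics, such as the number of pixels
+# classified as 'dog' in each image:
+
+for img_idx, mask in enumerate(boolean_dog_masks):
+    print(f"Image {img_idx}: {mask.sum().item()} 'dog' pixels out of {mask.numel()}")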
+
+#####################################
+# .. note::
+#     While we're using the ``normalized_masks`` here, we would have
+#     gotten the same result by using the non-normalized scores of the model
+#     directly (as the softmax operation preserves the order).
+#
+# Now that we have boolean masks, we can use them with
+# :func:`~torchvision.utils.draw_segmentation_masks` to plot them on top of the
+# original images:
+
+from torchvision.utils import draw_segmentation_masks
+
+dogs_with_masks = [
+    draw_segmentation_masks(img, masks=mask, alpha=0.3)
+    for img, mask in zip(batch, boolean_dog_masks)
+]
+show(dogs_with_masks)
+
+#####################################
+# We can plot more than one mask per image! Remember that the model returned as
+# many masks as there are classes. Let's ask the same query as above, but this
+# time for *all* classes, not just the dog class: "For each pixel and each class
+# C, is class C the most likely class?"
+#
+# This one is a bit more involved, so we'll first show how to do it with a
+# single image, and then we'll generalize to the batch.
+
+num_classes = normalized_masks.shape[1]
+dog1_masks = normalized_masks[0]
+class_dim = 0
+dog1_all_classes_masks = dog1_masks.argmax(class_dim) == torch.arange(num_classes)[:, None, None]
+
+print(f"dog1_masks shape = {dog1_masks.shape}, dtype = {dog1_masks.dtype}")
+print(f"dog1_all_classes_masks shape = {dog1_all_classes_masks.shape}, dtype = {dog1_all_classes_masks.dtype}")
+
+dog_with_all_masks = draw_segmentation_masks(dog1, masks=dog1_all_classes_masks, alpha=.4)
+show(dog_with_all_masks)
+
+#####################################
+# We can see in the image above that only 2 masks were drawn: the mask for the
+# background and the mask for the dog. This is because the model thinks that
+# only these 2 classes are the most likely ones across all the pixels. If the
+# model had detected another class as the most likely among other pixels, we
+# would have seen its mask above.
+#
+# Removing the background mask is as simple as passing
+# ``masks=dog1_all_classes_masks[1:]``, because the background class is the
+# class with index 0.
+#
+# Let's now do the same but for an entire batch of images. The code is similar
+# but involves a bit more juggling with the dimensions.

+class_dim = 1
+all_classes_masks = normalized_masks.argmax(class_dim) == torch.arange(num_classes)[:, None, None, None]
+print(f"shape = {all_classes_masks.shape}, dtype = {all_classes_masks.dtype}")
+# The first dimension is the classes now, so we need to swap it
+all_classes_masks = all_classes_masks.swapaxes(0, 1)
+
+dogs_with_masks = [
+    draw_segmentation_masks(img, masks=mask, alpha=.4)
+    for img, mask in zip(batch, all_classes_masks)
+]
+show(dogs_with_masks)
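+
+#####################################
+# As mentioned above, we can drop the background mask (class index 0) before
+# drawing. This is just a small variant of the previous cell:
+
+dogs_with_masks_no_background = [
+    draw_segmentation_masks(img, masks=mask[1:], alpha=.4)
+    for img, mask in zip(batch, all_classes_masks)
+]
+show(dogs_with_masks_no_background)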
+
+
+#####################################
+# Instance segmentation models
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+#
+# Instance segmentation models have a significantly different output from the
+# semantic segmentation models. We will see here how to plot the masks for such
+# models. Let's start by analyzing the output of a Mask R-CNN model. Note that
+# these models don't require the images to be normalized, so we don't need to
+# use the normalized batch.
+
+from torchvision.models.detection import maskrcnn_resnet50_fpn
+model = maskrcnn_resnet50_fpn(pretrained=True, progress=False)
+model = model.eval()
+
+output = model(batch)
+print(output)
+
+#####################################
+# Let's break this down. For each image in the batch, the model outputs some
+# detections (or instances). The number of detections varies for each input
+# image. Each instance is described by its bounding box, its label, its score
+# and its mask.
+#
+# The way the output is organized is as follows: the output is a list of length
+# ``batch_size``. Each entry in the list corresponds to an input image, and it
+# is a dict with keys 'boxes', 'labels', 'scores', and 'masks'. Each value
+# associated with those keys has ``num_instances`` elements in it. In our case
+# above there are 3 instances detected in the first image, and 2 instances in
+# the second one.
+#
+# The boxes can be plotted with :func:`~torchvision.utils.draw_bounding_boxes`
+# as above, but here we're more interested in the masks. These masks are quite
+# different from the masks that we saw above for the semantic segmentation
+# models.
+
+dog1_output = output[0]
+dog1_masks = dog1_output['masks']
+print(f"shape = {dog1_masks.shape}, dtype = {dog1_masks.dtype}, "
+      f"min = {dog1_masks.min()}, max = {dog1_masks.max()}")
+
+#####################################
+# Here the masks correspond to probabilities indicating, for each pixel, how
+# likely it is to belong to the predicted label of that instance. Those
+# predicted labels correspond to the 'labels' element in the same output dict.
+# Let's see which labels were predicted for the instances of the first image.
+
+inst_classes = [
+    '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
+    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
+    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
+    'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
+    'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
+    'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
+    'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
+    'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
+    'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
+    'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
+    'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
+    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
+]
+
+inst_class_to_idx = {cls: idx for (idx, cls) in enumerate(inst_classes)}
+
+print("For the first dog, the following instances were detected:")
+print([inst_classes[label] for label in dog1_output['labels']])
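+
+#####################################
+# Since we now have the labels, we can illustrate the point made above and draw
+# the first image's instance boxes with their labels (a quick sketch reusing
+# :func:`~torchvision.utils.draw_bounding_boxes`):
+
+show(draw_bounding_boxes(dog1, boxes=dog1_output['boxes'],
+                         labels=[inst_classes[label] for label in dog1_output['labels']],
+                         width=4))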
+
+#####################################
+# Interestingly, the model detects two persons in the image. Let's go ahead and
+# plot those masks. Since :func:`~torchvision.utils.draw_segmentation_masks`
+# expects boolean masks, we need to convert those probabilities into boolean
+# values. Remember that the semantics of those masks is "How likely is this
+# pixel to belong to the predicted class?". As a result, a natural way of
+# converting those masks into boolean values is to threshold them at a 0.5
+# probability (one could also choose a different threshold).
+
+proba_threshold = 0.5
+dog1_bool_masks = dog1_output['masks'] > proba_threshold
+print(f"shape = {dog1_bool_masks.shape}, dtype = {dog1_bool_masks.dtype}")
+
+# There's an extra dimension (1) to the masks. We need to remove it
+dog1_bool_masks = dog1_bool_masks.squeeze(1)
+
+show(draw_segmentation_masks(dog1, dog1_bool_masks, alpha=0.1))
+
+#####################################
+# The model seems to have properly detected the dog, but it also confused trees
+# with people. Looking more closely at the scores will help us plot more
+# relevant masks:
+
+print(dog1_output['scores'])
+
+#####################################
+# Clearly the model is less confident about the dog detection than it is about
+# the people detections. That's good news. When plotting the masks, we can ask
+# for only those that have a good score. Let's use a score threshold of .75
+# here, and also plot the masks of the second dog.
+
+score_threshold = .75
+
+boolean_masks = [
+    out['masks'][out['scores'] > score_threshold] > proba_threshold
+    for out in output
+]
+
+dogs_with_masks = [
+    draw_segmentation_masks(img, mask.squeeze(1))
+    for img, mask in zip(batch, boolean_masks)
+]
+show(dogs_with_masks)
+
+#####################################
+# The two 'people' masks in the first image were not selected because they have
+# a lower score than the score threshold. Similarly in the second image, the
+# instance with class 15 (which corresponds to 'bench') was not selected.
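+
+#####################################
+# To double check this (purely illustrative), we can print the labels and
+# scores of the second image's detections:
+
+print([(inst_classes[label], score.item())
+       for label, score in zip(output[1]['labels'], output[1]['scores'])])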
diff --git a/test/assets/fakedata/draw_segm_masks_colors_util.png b/test/assets/fakedata/draw_segm_masks_colors_util.png
deleted file mode 100644
index 454b3555631..00000000000
Binary files a/test/assets/fakedata/draw_segm_masks_colors_util.png and /dev/null differ
diff --git a/test/assets/fakedata/draw_segm_masks_no_colors_util.png b/test/assets/fakedata/draw_segm_masks_no_colors_util.png
deleted file mode 100644
index f048d2469d2..00000000000
Binary files a/test/assets/fakedata/draw_segm_masks_no_colors_util.png and /dev/null differ
diff --git a/test/test_utils.py b/test/test_utils.py
index 8c4cc620229..7e6ac7b2025 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -1,3 +1,4 @@
+import pytest
 import numpy as np
 import os
 import sys
@@ -7,7 +8,7 @@
 import unittest
 from io import BytesIO
 import torchvision.transforms.functional as F
-from PIL import Image, __version__ as PILLOW_VERSION
+from PIL import Image, __version__ as PILLOW_VERSION, ImageColor

 PILLOW_VERSION = tuple(int(x) for x in PILLOW_VERSION.split('.'))
@@ -151,7 +152,7 @@ def test_draw_boxes_vanilla(self):
     def test_draw_invalid_boxes(self):
         img_tp = ((1, 1, 1), (1, 2, 3))
-        img_wrong1 = torch.full((3, 5, 5), 255, dtype=torch.float)
+        img_wrong1 = torch.full((3, 5, 5), 255, dtype=torch.int)
         img_wrong2 = torch.full((1, 3, 5, 5), 255, dtype=torch.uint8)
         boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
@@ -159,55 +160,132 @@ def test_draw_invalid_boxes(self):
         self.assertRaises(ValueError, utils.draw_bounding_boxes, img_wrong1, boxes)
         self.assertRaises(ValueError, utils.draw_bounding_boxes, img_wrong2, boxes)

-    def test_draw_segmentation_masks_colors(self):
-        img = torch.full((3, 5, 5), 255, dtype=torch.uint8)
-        img_cp = img.clone()
-        masks_cp = masks.clone()
-        colors = ["#FF00FF", (0, 255, 0), "red"]
-        result = utils.draw_segmentation_masks(img, masks, colors=colors)
-
-        path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets",
-                            "fakedata", "draw_segm_masks_colors_util.png")
-
-        if not os.path.exists(path):
-            res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy())
-            res.save(path)
-
-        expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
-        self.assertTrue(torch.equal(result, expected))
-        # Check if modification is not in place
-        self.assertTrue(torch.all(torch.eq(img, img_cp)).item())
-        self.assertTrue(torch.all(torch.eq(masks, masks_cp)).item())
-
-    def test_draw_segmentation_masks_no_colors(self):
-        img = torch.full((3, 20, 20), 255, dtype=torch.uint8)
-        img_cp = img.clone()
-        masks_cp = masks.clone()
-        result = utils.draw_segmentation_masks(img, masks, colors=None)
-
-        path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets",
-                            "fakedata", "draw_segm_masks_no_colors_util.png")
-
-        if not os.path.exists(path):
-            res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy())
-            res.save(path)
-
-        expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
-        self.assertTrue(torch.equal(result, expected))
-        # Check if modification is not in place
-        self.assertTrue(torch.all(torch.eq(img, img_cp)).item())
-        self.assertTrue(torch.all(torch.eq(masks, masks_cp)).item())
-
-    def test_draw_invalid_masks(self):
-        img_tp = ((1, 1, 1), (1, 2, 3))
-        img_wrong1 = torch.full((3, 5, 5), 255, dtype=torch.float)
-        img_wrong2 = torch.full((1, 3, 5, 5), 255, dtype=torch.uint8)
-        img_wrong3 = torch.full((4, 5, 5), 255, dtype=torch.uint8)
-        self.assertRaises(TypeError, utils.draw_segmentation_masks, img_tp, masks)
-        self.assertRaises(ValueError, utils.draw_segmentation_masks, img_wrong1, masks)
-        self.assertRaises(ValueError, utils.draw_segmentation_masks, img_wrong2, masks)
-        self.assertRaises(ValueError, utils.draw_segmentation_masks, img_wrong3, masks)
+
+@pytest.mark.parametrize('fill', (False, True))
+def test_draw_bounding_boxes_int_vs_float(fill):
+    """Make sure float and uint8 dtypes produce similar images"""
+    h, w = 500, 500
+    img_int = torch.randint(0, 256, size=(3, h, w), dtype=torch.uint8)
+    img_float = F.convert_image_dtype(img_int, torch.float)
+
+    boxes = torch.tensor([[50, 50, 100, 200], [210, 150, 350, 430]], dtype=torch.float)
+    out_int = utils.draw_bounding_boxes(image=img_int, boxes=boxes, colors=['red', 'blue'], fill=fill)
+    out_float = utils.draw_bounding_boxes(image=img_float, boxes=boxes, colors=['red', 'blue'], fill=fill)
+
+    assert out_int.dtype == img_int.dtype
+    assert out_float.dtype == img_float.dtype
+
+    out_float_int = F.convert_image_dtype(out_float, torch.uint8).int()
+    out_int = out_int.int()
+
+    assert (out_int - out_float_int).abs().max() <= 1
+
+
+@pytest.mark.parametrize('dtype', (torch.float, torch.uint8))
+@pytest.mark.parametrize('colors', [
+    None,
+    ['red', 'blue'],
+    ['#FF00FF', (1, 34, 122)],
+])
+@pytest.mark.parametrize('alpha', (0, .5, .7, 1))
+def test_draw_segmentation_masks(dtype, colors, alpha):
+    """This test makes sure that masks draw their corresponding color where they should"""
+    num_masks, h, w = 2, 100, 100
+    # Build the image as uint8 and convert it for the float case, so that float
+    # images stay in the documented [0, 1] range
+    img = torch.randint(0, 256, size=(3, h, w), dtype=torch.uint8)
+    if dtype == torch.float:
+        img = F.convert_image_dtype(img, torch.float)
+    masks = torch.randint(0, 2, (num_masks, h, w), dtype=torch.bool)
+
+    # For testing we enforce that there's no overlap between the masks: the
+    # current behaviour is that the last mask's color takes priority when masks
+    # overlap, and allowing overlap would just make the checks below more
+    # complicated
+    overlap = masks[0] & masks[1]
+    masks[:, overlap] = False
+
+    out = utils.draw_segmentation_masks(img, masks, colors=colors, alpha=alpha)
+    assert out.dtype == dtype
+    assert out is not img
+
+    # Make sure the image didn't change where there's no mask
+    masked_pixels = masks[0] | masks[1]
+    assert (img[:, ~masked_pixels] == out[:, ~masked_pixels]).all()
+
+    if colors is None:
+        colors = utils._generate_color_palette(num_masks)
+
+    # Make sure each mask draws with its own color
+    for mask, color in zip(masks, colors):
+        if isinstance(color, str):
+            color = ImageColor.getrgb(color)
+        color = torch.tensor(color, dtype=dtype)
+        if dtype == torch.float:
+            color /= 255
+
+        if alpha == 0:
+            assert (out[:, mask] == color[:, None]).all()
+        elif alpha == 1:
+            assert (out[:, mask] == img[:, mask]).all()
+
+        interpolated_color = (img[:, mask] * alpha + color[:, None] * (1 - alpha))
+        max_diff = (out[:, mask] - interpolated_color).abs().max()
+        if dtype == torch.uint8:
+            assert max_diff <= 1
+        else:
+            assert max_diff <= 1e-5
+
+
+def test_draw_segmentation_masks_int_vs_float():
+    """Make sure float and uint8 dtypes produce similar images"""
+    h, w = 100, 100
+    masks = torch.randint(0, 2, size=(2, h, w), dtype=torch.bool)
+    img_int = torch.randint(0, 256, size=(3, h, w), dtype=torch.uint8)
+    img_float = F.convert_image_dtype(img_int, torch.float)
+
+    out_int = utils.draw_segmentation_masks(image=img_int, masks=masks, colors=['red', 'blue'])
+    out_float = utils.draw_segmentation_masks(image=img_float, masks=masks, colors=['red', 'blue'])
+
+    assert out_int.dtype == img_int.dtype
+    assert out_float.dtype == img_float.dtype
+
+    out_float_int = F.convert_image_dtype(out_float, torch.uint8).int()
+    out_int = out_int.int()
+
+    assert (out_int - out_float_int).abs().max() <= 1
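+
+
+def test_generate_color_palette():
+    # Small sanity-check sketch (note: _generate_color_palette is a private
+    # helper introduced in this PR): it should return one RGB triplet per mask
+    num_masks = 4
+    colors = utils._generate_color_palette(num_masks)
+    assert len(colors) == num_masks
+    assert all(isinstance(color, tuple) and len(color) == 3 for color in colors)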
+
+
+def test_draw_segmentation_masks_errors():
+    h, w = 10, 10
+
+    masks = torch.randint(0, 2, size=(h, w), dtype=torch.bool)
+    img = torch.randint(0, 256, size=(3, h, w), dtype=torch.uint8)
+
+    with pytest.raises(TypeError, match="The image must be a tensor"):
+        utils.draw_segmentation_masks(image="Not A Tensor Image", masks=masks)
+    with pytest.raises(ValueError, match="The image dtype must be"):
+        img_bad_dtype = torch.randint(0, 256, size=(3, h, w), dtype=torch.int64)
+        utils.draw_segmentation_masks(image=img_bad_dtype, masks=masks)
+    with pytest.raises(ValueError, match="Pass individual images, not batches"):
+        batch = torch.randint(0, 256, size=(10, 3, h, w), dtype=torch.uint8)
+        utils.draw_segmentation_masks(image=batch, masks=masks)
+    with pytest.raises(ValueError, match="Pass an RGB image"):
+        one_channel = torch.randint(0, 256, size=(1, h, w), dtype=torch.uint8)
+        utils.draw_segmentation_masks(image=one_channel, masks=masks)
+    with pytest.raises(ValueError, match="The masks must be of dtype bool"):
+        masks_bad_dtype = torch.randint(0, 2, size=(h, w), dtype=torch.float)
+        utils.draw_segmentation_masks(image=img, masks=masks_bad_dtype)
+    with pytest.raises(ValueError, match="masks must be of shape"):
+        masks_bad_shape = torch.randint(0, 2, size=(3, 2, h, w), dtype=torch.bool)
+        utils.draw_segmentation_masks(image=img, masks=masks_bad_shape)
+    with pytest.raises(ValueError, match="must have the same height and width"):
+        masks_bad_shape = torch.randint(0, 2, size=(h + 4, w), dtype=torch.bool)
+        utils.draw_segmentation_masks(image=img, masks=masks_bad_shape)
+    with pytest.raises(ValueError, match="There are more masks"):
+        utils.draw_segmentation_masks(image=img, masks=masks, colors=[])
+    with pytest.raises(ValueError, match="colors must be a tuple or a string, or a list thereof"):
+        bad_colors = np.array(['red', 'blue'])  # should be a list
+        utils.draw_segmentation_masks(image=img, masks=masks, colors=bad_colors)
+    with pytest.raises(ValueError, match="It seems that you passed a tuple of colors instead of"):
+        bad_colors = ('red', 'blue')  # should be a list
+        utils.draw_segmentation_masks(image=img, masks=masks, colors=bad_colors)


 if __name__ == '__main__':
diff --git a/torchvision/utils.py b/torchvision/utils.py
index 9d9bbdb3c80..ebaec520da1 100644
--- a/torchvision/utils.py
+++ b/torchvision/utils.py
@@ -150,7 +150,7 @@ def draw_bounding_boxes(
     """
     Draws bounding boxes on given image.

-    The values of the input image should be uint8 between 0 and 255.
+    The values of the input image should be uint8 between 0 and 255 or float between 0 and 1.
     If fill is True, Resulting Tensor should be saved as PNG image.

     Args:
@@ -174,46 +174,51 @@
     if not isinstance(image, torch.Tensor):
         raise TypeError(f"Tensor expected, got {type(image)}")
-    elif image.dtype != torch.uint8:
-        raise ValueError(f"Tensor uint8 expected, got {image.dtype}")
+    elif image.dtype not in (torch.uint8, torch.float):
+        raise ValueError(f"Tensor with dtype uint8 or float expected, got {image.dtype}")
     elif image.dim() != 3:
         raise ValueError("Pass individual images, not batches")

+    dtype = image.dtype
     ndarr = image.permute(1, 2, 0).numpy()
+    if dtype == torch.float:
+        ndarr = np.clip(ndarr * 255, 0, 255).astype(np.uint8)
     img_to_draw = Image.fromarray(ndarr)

     img_boxes = boxes.to(torch.int64).tolist()

-    if fill:
-        draw = ImageDraw.Draw(img_to_draw, "RGBA")
+    if colors is None:
+        colors = [None] * len(img_boxes)
+    if labels is None:
+        labels = [None] * len(img_boxes)

-    else:
-        draw = ImageDraw.Draw(img_to_draw)
+    draw = ImageDraw.Draw(img_to_draw, "RGBA")

     txt_font = ImageFont.load_default() if font is None else ImageFont.truetype(font=font, size=font_size)

-    for i, bbox in enumerate(img_boxes):
-        if colors is None:
-            color = None
-        else:
-            color = colors[i]
+    for bbox, color, label in zip(img_boxes, colors, labels):
         if fill:
             if color is None:
-                fill_color = (255, 255, 255, 100)
+                fill_color = (255, 255, 255)
             elif isinstance(color, str):
                 # This will automatically raise Error if rgb cannot be parsed.
- fill_color = ImageColor.getrgb(color) + (100,) + fill_color = ImageColor.getrgb(color) elif isinstance(color, tuple): - fill_color = color + (100,) - draw.rectangle(bbox, width=width, outline=color, fill=fill_color) + fill_color = color + fill_color = fill_color + (100,) else: - draw.rectangle(bbox, width=width, outline=color) + fill_color = None - if labels is not None: - draw.text((bbox[0], bbox[1]), labels[i], fill=color, font=txt_font) + draw.rectangle(bbox, width=width, outline=color, fill=fill_color) - return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8) + if label is not None: + draw.text((bbox[0], bbox[1]), label, fill=color, font=txt_font) + + out = torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=dtype) + if dtype == torch.float: + out /= 255 + return out @torch.no_grad() @@ -226,52 +231,73 @@ def draw_segmentation_masks( """ Draws segmentation masks on given RGB image. - The values of the input image should be uint8 between 0 and 255. + The values of the input image should be uint8 between 0 and 255, or float values between 0 and 1. Args: - image (Tensor): Tensor of shape (3 x H x W) and dtype uint8. - masks (Tensor): Tensor of shape (num_masks, H, W). Each containing probability of predicted class. - alpha (float): Float number between 0 and 1 denoting factor of transparency of masks. - colors (List[Union[str, Tuple[int, int, int]]]): List containing the colors of masks. The colors can - be represented as `str` or `Tuple[int, int, int]`. + image (Tensor): Tensor of shape (3, H, W) and dtype uint8 or float. + masks (Tensor): Tensor of shape (num_masks, H, W) or (H, W) and dtype bool. + alpha (float): Float number between 0 and 1 denoting the transparency of the masks. + colors (list or None): List containing the colors of the masks. The colors can + be represented as PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``. + When ``masks`` has a single entry of shape (H, W), you can pass a single color instead of a list + with one element. By default, random colors are generated for each mask. Returns: - img (Tensor[C, H, W]): Image Tensor of dtype uint8 with segmentation masks plotted. + img (Tensor[C, H, W]): Image Tensor with the same dtype as the input image, with segmentation masks + drawn on top. """ if not isinstance(image, torch.Tensor): - raise TypeError(f"Tensor expected, got {type(image)}") - elif image.dtype != torch.uint8: - raise ValueError(f"Tensor uint8 expected, got {image.dtype}") + raise TypeError(f"The image must be a tensor, got {type(image)}") + elif image.dtype not in (torch.uint8, torch.float): + raise ValueError(f"The image dtype must be uint8 or float, got {image.dtype}") elif image.dim() != 3: raise ValueError("Pass individual images, not batches") elif image.size()[0] != 3: raise ValueError("Pass an RGB image. Other Image formats are not supported") + if masks.ndim == 2: + masks = masks[None, :, :] + if masks.ndim != 3: + raise ValueError("masks must be of shape (H, W) or (batch_size, H, W)") + if masks.dtype != torch.bool: + raise ValueError(f"The masks must be of dtype bool. 
Got {masks.dtype}") + if masks.shape[-2:] != image.shape[-2:]: + raise ValueError("The image and the masks must have the same height and width") num_masks = masks.size()[0] - masks = masks.argmax(0) + if colors is not None and num_masks > len(colors): + raise ValueError(f"There are more masks ({num_masks}) than colors ({len(colors)})") if colors is None: - palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1]) - colors_t = torch.as_tensor([i for i in range(num_masks)])[:, None] * palette - color_arr = (colors_t % 255).numpy().astype("uint8") - else: - color_list = [] - for color in colors: - if isinstance(color, str): - # This will automatically raise Error if rgb cannot be parsed. - fill_color = ImageColor.getrgb(color) - color_list.append(fill_color) - elif isinstance(color, tuple): - color_list.append(color) - - color_arr = np.array(color_list).astype("uint8") - - _, h, w = image.size() - img_to_draw = Image.fromarray(masks.byte().cpu().numpy()).resize((w, h)) - img_to_draw.putpalette(color_arr) - - img_to_draw = torch.from_numpy(np.array(img_to_draw.convert('RGB'))) - img_to_draw = img_to_draw.permute((2, 0, 1)) - - return (image.float() * alpha + img_to_draw.float() * (1.0 - alpha)).to(dtype=torch.uint8) + colors = _generate_color_palette(num_masks) + + if not isinstance(colors, list): + colors = [colors] + if not isinstance(colors[0], (tuple, str)): + raise ValueError("colors must be a tuple or a string, or a list thereof") + if isinstance(colors[0], tuple) and len(colors[0]) != 3: + raise ValueError("It seems that you passed a tuple of colors instead of a list of colors") + + out_dtype = image.dtype + + colors_ = [] + for color in colors: + if isinstance(color, str): + color = ImageColor.getrgb(color) + color = torch.tensor(color, dtype=out_dtype) + if out_dtype == torch.float: + color /= 255 + colors_.append(color) + + img_to_draw = image.detach().clone() + # TODO: There might be a way to vectorize this + for mask, color in zip(masks, colors_): + img_to_draw[:, mask] = color[:, None] + + out = image * alpha + img_to_draw * (1 - alpha) + return out.to(out_dtype) + + +def _generate_color_palette(num_masks): + palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1]) + return [tuple((i * palette) % 255) for i in range(num_masks)]