diff --git a/.gitignore b/.gitignore
index 3c7e579c23c..e6e4e0f3728 100644
--- a/.gitignore
+++ b/.gitignore
@@ -25,3 +25,4 @@ gen.yml
 .idea/
 *.orig
 *-checkpoint.ipynb
+*.venv
diff --git a/test/assets/fakedata/draw_boxes_util.png b/test/assets/fakedata/draw_boxes_util.png
index e6b9286bf92..d64fa2f1f36 100644
Binary files a/test/assets/fakedata/draw_boxes_util.png and b/test/assets/fakedata/draw_boxes_util.png differ
diff --git a/test/test_utils.py b/test/test_utils.py
index 21e2ab461d7..662ad2a0cce 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -86,7 +86,7 @@ def test_draw_boxes(self):
                              [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
         labels = ["a", "b", "c", "d"]
         colors = ["green", "#FF00FF", (0, 255, 0), "red"]
-        result = utils.draw_bounding_boxes(img, boxes, labels=labels, colors=colors)
+        result = utils.draw_bounding_boxes(img, boxes, labels=labels, colors=colors, fill=True)
 
         path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_boxes_util.png")
         if not os.path.exists(path):
diff --git a/torchvision/utils.py b/torchvision/utils.py
index 6290809a7d6..9ee5a0cc65c 100644
--- a/torchvision/utils.py
+++ b/torchvision/utils.py
@@ -4,8 +4,7 @@
 import math
 import warnings
 import numpy as np
-from PIL import Image, ImageDraw
-from PIL import ImageFont
+from PIL import Image, ImageDraw, ImageFont, ImageColor
 
 __all__ = ["make_grid", "save_image", "draw_bounding_boxes"]
 
@@ -142,6 +141,7 @@ def draw_bounding_boxes(
     boxes: torch.Tensor,
     labels: Optional[List[str]] = None,
     colors: Optional[List[Union[str, Tuple[int, int, int]]]] = None,
+    fill: bool = False,
     width: int = 1,
     font: Optional[str] = None,
     font_size: int = 10
@@ -150,6 +150,7 @@ def draw_bounding_boxes(
     """
     Draws bounding boxes on given image.
     The values of the input image should be uint8 between 0 and 255.
+    If `fill` is True, the resulting Tensor should be saved as a PNG image.
 
     Args:
         image (Tensor): Tensor of shape (C x H x W)
@@ -159,6 +160,7 @@ def draw_bounding_boxes(
         labels (List[str]): List containing the labels of bounding boxes.
         colors (List[Union[str, Tuple[int, int, int]]]): List containing the colors of bounding boxes. The colors can
             be represented as `str` or `Tuple[int, int, int]`.
+        fill (bool): If `True`, fills each box with a translucent shade of its color (white if no color is given).
         width (int): Width of bounding box.
         font (str): A filename containing a TrueType font. If the file is not found in this filename, the loader may
             also search in other directories, such as the `fonts/` directory on Windows or `/Library/Fonts/`,
@@ -178,12 +180,26 @@ def draw_bounding_boxes(
 
     img_boxes = boxes.to(torch.int64).tolist()
 
-    draw = ImageDraw.Draw(img_to_draw)
+    # The "RGBA" drawing mode lets Pillow blend the translucent fill into the RGB image.
+    draw = ImageDraw.Draw(img_to_draw, "RGBA") if fill else ImageDraw.Draw(img_to_draw)
     txt_font = ImageFont.load_default() if font is None else ImageFont.truetype(font=font, size=font_size)
 
     for i, bbox in enumerate(img_boxes):
         color = None if colors is None else colors[i]
-        draw.rectangle(bbox, width=width, outline=color)
+
+        if fill:
+            if color is None:
+                fill_color = (255, 255, 255, 100)
+            elif isinstance(color, str):
+                # This will automatically raise Error if rgb cannot be parsed.
+                fill_color = ImageColor.getrgb(color) + (100,)
+            else:
+                # Any other RGB sequence (tuple, list, ...). Without this fallback branch
+                # fill_color would be left unbound for non-tuple sequences.
+                fill_color = tuple(color) + (100,)
+            draw.rectangle(bbox, width=width, outline=color, fill=fill_color)
+        else:
+            draw.rectangle(bbox, width=width, outline=color)
 
         if labels is not None:
             draw.text((bbox[0], bbox[1]), labels[i], fill=color, font=txt_font)