Description
🚀 Feature
We added utilities to plot bounding boxes (#2556, #2785) and semantic segmentation masks (#3272, #3330, #3820).
I think a utility to plot keypoints is the last plotting utility we need.
Motivation
As with the previous utilities, we need this to make visualization and post-processing easier.
It would also make it easy to visualize the outputs of torchvision models.
Pitch
Use only PIL to draw keypoints, and keep the function signature consistent with the previous utilities.
Here is a prototype of the function signature:
```python
from typing import List, Optional, Tuple, Union

import torch


@torch.no_grad()
def draw_keypoints(
    image: torch.Tensor,
    keypoints: torch.Tensor,
    colors: Optional[List[Union[str, Tuple[int, int, int]]]] = None,
    radius: int = 5,
    connect: bool = False,
    plot_ids: bool = False,
    font: Optional[str] = None,
    font_size: int = 10,
) -> torch.Tensor:
    """
    Args:
        image (Tensor): Tensor of shape (C x H x W) and dtype uint8.
        keypoints (Tensor): Tensor of shape (N, K, 3) holding the keypoints from Keypoint RCNN,
            in [x, y, visibility] format.
        colors (List[Union[str, Tuple[int, int, int]]]): List containing the colors of keypoints. The colors can
            be represented as `str` or `Tuple[int, int, int]`.
        radius (int): Radius of circles drawn for keypoints.
        connect (bool): If True, connects the visible keypoints.
        plot_ids (bool): If True, plots the ids of the visible keypoints on the image.
        font (str): A filename containing a TrueType font. If the file is not found in this filename, the loader may
            also search in other directories, such as the `fonts/` directory on Windows or `/Library/Fonts/`,
            `/System/Library/Fonts/` and `~/Library/Fonts/` on macOS.
        font_size (int): The requested font size in points.
    """
```
The radius determines how large each keypoint is drawn. It is essential because torchvision models support flexible image sizes, so no single fixed radius suits every image.
The font is used to plot the keypoint ids. With models like PoseNet, we sometimes want to know which keypoint ids are visible.
This might help with debugging when a particular keypoint is not predicted well.
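Since a fixed radius can look tiny on large images and oversized on small ones, a caller might derive it from the image size before calling the utility. The helper `scaled_radius` below is hypothetical and not part of the proposal; the 1% fraction of the shorter side is an arbitrary default chosen for illustration.

```python
import torch


def scaled_radius(image: torch.Tensor, fraction: float = 0.01, min_radius: int = 1) -> int:
    """Hypothetical helper: pick a keypoint radius proportional to the image's shorter side."""
    _, height, width = image.shape  # image is assumed to be (C, H, W)
    return max(min_radius, round(fraction * min(height, width)))
```

For a 480 x 640 image this gives a radius of 5, while very small images fall back to `min_radius`.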
The output of Keypoint RCNN is as follows:
- boxes (FloatTensor[N, 4]): the ground-truth boxes in [x1, y1, x2, y2] format, with values of x between 0 and W and values of y between 0 and H
- labels (Int64Tensor[N]): the class label for each ground-truth box
- keypoints (FloatTensor[N, K, 3]): the K keypoint locations for each of the N instances, in the format [x, y, visibility], where visibility=0 means that the keypoint is not visible
I think we should plot only the visible keypoints, i.e. those with a non-zero visibility flag.
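Selecting the visible keypoints from a tensor of shape [N, K, 3] amounts to masking on the visibility column. The function `visible_keypoints` below is a hypothetical helper for illustration; it also keeps each visible keypoint's id, which is what `plot_ids` would need to draw.

```python
from typing import List, Tuple

import torch


def visible_keypoints(keypoints: torch.Tensor) -> List[Tuple[torch.Tensor, torch.Tensor]]:
    """Hypothetical helper: for each instance, return (ids, xy) of keypoints with visibility > 0."""
    mask = keypoints[..., 2] > 0  # (N, K) boolean mask over the visibility flag
    result = []
    for kpts, m in zip(keypoints, mask):
        ids = torch.nonzero(m).flatten()  # indices of the visible keypoints
        xy = kpts[m, :2]  # (K_visible, 2) coordinates
        result.append((ids, xy))
    return result
```

Each instance can have a different number of visible keypoints, which is why the result is a list rather than a single tensor.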
Additional context
The boxes can be plotted using draw_bounding_boxes. Users can combine both utilities to plot boxes and keypoints together. I don't think draw_keypoints
should call draw_bounding_boxes or otherwise draw boxes internally.
We should probably also provide a demo of it.