-
Notifications
You must be signed in to change notification settings - Fork 7.1k
Add utility to draw keypoints #4216
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
b2f6615
4fb038d
deda5d7
5490821
4cfc220
6306746
e8c93cf
73abf84
2ebe0c7
77afa81
cdeebdd
fad0d44
76af22e
1291f44
b9af874
28ebcc0
8db61f7
ebe7a25
e6afa37
c46d9db
691562c
4a65900
d5747d3
a034137
1f41550
d9d96cb
e6e7428
c8da898
4385643
8997b58
0002677
949e42c
693ffdc
00db80b
41ecc44
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,7 +7,7 @@ | |
import torch | ||
from PIL import Image, ImageDraw, ImageFont, ImageColor | ||
|
||
__all__ = ["make_grid", "save_image", "draw_bounding_boxes", "draw_segmentation_masks"] | ||
__all__ = ["make_grid", "save_image", "draw_bounding_boxes", "draw_segmentation_masks", "draw_keypoints"] | ||
|
||
|
||
@torch.no_grad() | ||
|
@@ -300,6 +300,76 @@ def draw_segmentation_masks( | |
return out.to(out_dtype) | ||
|
||
|
||
@torch.no_grad() | ||
def draw_keypoints( | ||
image: torch.Tensor, | ||
keypoints: torch.Tensor, | ||
connectivity: Optional[Tuple[Tuple[int, int]]] = None, | ||
colors: Optional[Union[str, Tuple[int, int, int]]] = None, | ||
radius: int = 2, | ||
width: int = 3, | ||
) -> torch.Tensor: | ||
|
||
""" | ||
Draws Keypoints on given RGB image. | ||
The values of the input image should be uint8 between 0 and 255. | ||
|
||
Args: | ||
image (Tensor): Tensor of shape (3, H, W) and dtype uint8. | ||
keypoints (Tensor): Tensor of shape (num_instances, K, 2) the K keypoints location for each of the N instances, | ||
oke-aditya marked this conversation as resolved.
Show resolved
Hide resolved
|
||
in the format [x, y]. | ||
connectivity (Tuple[Tuple[int, int]]]): A Tuple of tuple where, | ||
each tuple contains pair of keypoints to be connected. | ||
colors (str, Tuple): The color can be represented as | ||
PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``. | ||
radius (int): Integer denoting radius of keypoint. | ||
width (int): Integer denoting width of line connecting keypoints. | ||
|
||
Returns: | ||
img (Tensor[C, H, W]): Image Tensor of dtype uint8 with keypoints drawn. | ||
""" | ||
|
||
if not isinstance(image, torch.Tensor): | ||
raise TypeError(f"The image must be a tensor, got {type(image)}") | ||
elif image.dtype != torch.uint8: | ||
raise ValueError(f"The image dtype must be uint8, got {image.dtype}") | ||
elif image.dim() != 3: | ||
raise ValueError("Pass individual images, not batches") | ||
elif image.size()[0] != 3: | ||
raise ValueError("Pass an RGB image. Other Image formats are not supported") | ||
|
||
if keypoints.ndim != 3: | ||
raise ValueError("keypoints must be of shape (num_instances, K, 2)") | ||
|
||
ndarr = image.permute(1, 2, 0).numpy() | ||
img_to_draw = Image.fromarray(ndarr) | ||
draw = ImageDraw.Draw(img_to_draw) | ||
img_kpts = keypoints.to(torch.int64).tolist() | ||
|
||
for kpt_id, kpt_inst in enumerate(img_kpts): | ||
for inst_id, kpt in enumerate(kpt_inst): | ||
x1 = kpt[0] - radius | ||
x2 = kpt[0] + radius | ||
y1 = kpt[1] - radius | ||
y2 = kpt[1] + radius | ||
draw.ellipse([x1, y1, x2, y2], fill=colors, outline=None, width=0) | ||
|
||
if connectivity: | ||
for connection in connectivity: | ||
oke-aditya marked this conversation as resolved.
Show resolved
Hide resolved
|
||
start_pt_x = kpt_inst[connection[0]][0] | ||
start_pt_y = kpt_inst[connection[0]][1] | ||
|
||
end_pt_x = kpt_inst[connection[1]][0] | ||
end_pt_y = kpt_inst[connection[1]][1] | ||
|
||
draw.line( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What should be the line color? Should there be a parameter for line color? Or should it be same as color of keypoints? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In the interest of getting this merged soon, let's leave it white for now, and in the future create an issue to enable this to be configurable. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sounds good. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As we have sufficient time before the next release. Let's discuss about this as well. |
||
((start_pt_x, start_pt_y), (end_pt_x, end_pt_y)), | ||
width=width, | ||
) | ||
|
||
return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8) | ||
|
||
|
||
def _generate_color_palette(num_masks: int): | ||
palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1]) | ||
return [tuple((i * palette) % 255) for i in range(num_masks)] | ||
|
Uh oh!
There was an error while loading. Please reload this page.