
Commit 30669a5

Authored by oke-aditya, pmeier and datumbox
Add gallery example for drawing keypoints (#4892)
* Start writing gallery example
* Remove the child image fix implementation add code
* add docs
* Apply suggestions from code review
  Co-authored-by: Vasilis Vryniotis <[email protected]>
* address review update thumbnail

Co-authored-by: Philip Meier <[email protected]>
Co-authored-by: Vasilis Vryniotis <[email protected]>
1 parent 85b7858 commit 30669a5

File tree

6 files changed: +127, -6 lines

gallery/assets/person1.jpg

68.5 KB (binary image file)

-187 KB: binary file not shown.

gallery/plot_visualization_utils.py

Lines changed: 109 additions & 2 deletions
@@ -4,10 +4,10 @@
 =======================

 This example illustrates some of the utilities that torchvision offers for
-visualizing images, bounding boxes, and segmentation masks.
+visualizing images, bounding boxes, segmentation masks and keypoints.
 """

-# sphinx_gallery_thumbnail_path = "../../gallery/assets/visualization_utils_thumbnail.png"
+# sphinx_gallery_thumbnail_path = "../../gallery/assets/visualization_utils_thumbnail2.png"

 import torch
 import numpy as np
@@ -366,3 +366,110 @@ def show(imgs):
 # The two 'people' masks in the first image were not selected because they have
 # a lower score than the score threshold. Similarly in the second image, the
 # instance with class 15 (which corresponds to 'bench') was not selected.
+
+#####################################
+# Visualizing keypoints
+# ------------------------------
+# The :func:`~torchvision.utils.draw_keypoints` function can be used to
+# draw keypoints on images. We will see how to use it with
+# torchvision's KeypointRCNN loaded with :func:`~torchvision.models.detection.keypointrcnn_resnet50_fpn`.
+# We will first have a look at the output of the model.
+#
+# Note that the keypoint detection model does not need normalized images.
+#
+
+from torchvision.models.detection import keypointrcnn_resnet50_fpn
+from torchvision.io import read_image
+
+person_int = read_image(str(Path("assets") / "person1.jpg"))
+person_float = convert_image_dtype(person_int, dtype=torch.float)
+
+model = keypointrcnn_resnet50_fpn(pretrained=True, progress=False)
+model = model.eval()
+
+outputs = model([person_float])
+print(outputs)
+
+#####################################
+# As we see, the output contains a list of dictionaries.
+# The output list is of length batch_size.
+# We currently have just a single image, so the length of the list is 1.
+# Each entry in the list corresponds to an input image,
+# and it is a dict with keys `boxes`, `labels`, `scores`, `keypoints` and `keypoint_scores`.
+# Each value associated with those keys has `num_instances` elements in it.
+# In our case above, there are 2 instances detected in the image.
+
+kpts = outputs[0]['keypoints']
+scores = outputs[0]['scores']
+
+print(kpts)
+print(scores)
+
+#####################################
+# The KeypointRCNN model detects that there are two instances in the image.
+# If you plot the boxes using :func:`~draw_bounding_boxes`,
+# you would recognize them as the person and the surfboard.
+# If we look at the scores, we will realize that the model is much more confident about the person than the surfboard.
+# We could now set a confidence threshold and plot only the instances we are confident enough about.
+# Let us set a threshold of 0.75 and keep the keypoints corresponding to the person.
+
+detect_threshold = 0.75
+idx = torch.where(scores > detect_threshold)
+keypoints = kpts[idx]
+
+print(keypoints)
+
+#####################################
+# Great, now we have the keypoints corresponding to the person.
+# Each keypoint is represented by x, y coordinates and the visibility.
+# We can now use the :func:`~torchvision.utils.draw_keypoints` function to draw keypoints.
+# Note that the utility expects uint8 images.
+
+from torchvision.utils import draw_keypoints
+
+res = draw_keypoints(person_int, keypoints, colors="blue", radius=3)
+show(res)
+
+#####################################
+# As we see, the keypoints appear as colored circles over the image.
+# The COCO keypoints for a person are ordered and represent the following list.
+
+coco_keypoints = [
+    "nose", "left_eye", "right_eye", "left_ear", "right_ear",
+    "left_shoulder", "right_shoulder", "left_elbow", "right_elbow",
+    "left_wrist", "right_wrist", "left_hip", "right_hip",
+    "left_knee", "right_knee", "left_ankle", "right_ankle",
+]
+
+#####################################
+# What if we are interested in joining the keypoints?
+# This is especially useful in tasks such as pose detection or action recognition.
+# We can join the keypoints easily using the `connectivity` parameter.
+# A close observation reveals that we need to join the points in the order below
+# to construct a human skeleton.
+#
+# nose -> left_eye -> left_ear.                              (0, 1), (1, 3)
+#
+# nose -> right_eye -> right_ear.                            (0, 2), (2, 4)
+#
+# nose -> left_shoulder -> left_elbow -> left_wrist.         (0, 5), (5, 7), (7, 9)
+#
+# nose -> right_shoulder -> right_elbow -> right_wrist.      (0, 6), (6, 8), (8, 10)
+#
+# left_shoulder -> left_hip -> left_knee -> left_ankle.      (5, 11), (11, 13), (13, 15)
+#
+# right_shoulder -> right_hip -> right_knee -> right_ankle.  (6, 12), (12, 14), (14, 16)
+#
+# We will create a list containing these keypoint ids to be connected.
+
+connect_skeleton = [
+    (0, 1), (0, 2), (1, 3), (2, 4), (0, 5), (0, 6), (5, 7), (6, 8),
+    (7, 9), (8, 10), (5, 11), (6, 12), (11, 13), (12, 14), (13, 15), (14, 16)
+]
+
+#####################################
+# We pass the above list to the connectivity parameter to connect the keypoints.
+#
+
+res = draw_keypoints(person_int, keypoints, connectivity=connect_skeleton, colors="blue", radius=4, width=3)
+show(res)
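The gallery text above mentions that plotting the boxes with :func:`~torchvision.utils.draw_bounding_boxes` would reveal the person and the surfboard, but that snippet is not part of this commit. A minimal sketch of how it could look, reusing person_int, outputs and show from the example and assuming the same 0.75 confidence threshold as the keypoint filtering step:

from torchvision.utils import draw_bounding_boxes

boxes = outputs[0]['boxes']           # (num_instances, 4) boxes in (xmin, ymin, xmax, ymax) format
keep = outputs[0]['scores'] > 0.75    # keep only the confident detections (the person)
res = draw_bounding_boxes(person_int, boxes[keep], colors="red", width=3)
show(res)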

test/test_utils.py

Lines changed: 16 additions & 2 deletions
@@ -256,7 +256,14 @@ def test_draw_keypoints_vanilla():
 
     img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
     img_cp = img.clone()
-    result = utils.draw_keypoints(img, keypoints, colors="red", connectivity=((0, 1),))
+    result = utils.draw_keypoints(
+        img,
+        keypoints,
+        colors="red",
+        connectivity=[
+            (0, 1),
+        ],
+    )
     path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_keypoint_vanilla.png")
     if not os.path.exists(path):
         res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy())
@@ -277,7 +284,14 @@ def test_draw_keypoints_colored(colors):
 
     img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
     img_cp = img.clone()
-    result = utils.draw_keypoints(img, keypoints, colors=colors, connectivity=((0, 1),))
+    result = utils.draw_keypoints(
+        img,
+        keypoints,
+        colors=colors,
+        connectivity=[
+            (0, 1),
+        ],
+    )
     assert result.size(0) == 3
     assert_equal(keypoints, keypoints_cp)
     assert_equal(img, img_cp)

torchvision/utils.py

Lines changed: 2 additions & 2 deletions
@@ -304,7 +304,7 @@ def draw_segmentation_masks(
 def draw_keypoints(
     image: torch.Tensor,
     keypoints: torch.Tensor,
-    connectivity: Optional[Tuple[Tuple[int, int]]] = None,
+    connectivity: Optional[List[Tuple[int, int]]] = None,
     colors: Optional[Union[str, Tuple[int, int, int]]] = None,
     radius: int = 2,
     width: int = 3,
@@ -318,7 +318,7 @@ def draw_keypoints(
         image (Tensor): Tensor of shape (3, H, W) and dtype uint8.
         keypoints (Tensor): Tensor of shape (num_instances, K, 2) the K keypoints location for each of the N instances,
             in the format [x, y].
-        connectivity (Tuple[Tuple[int, int]]]): A Tuple of tuple where,
+        connectivity (List[Tuple[int, int]]]): A List of tuple where,
            each tuple contains pair of keypoints to be connected.
         colors (str, Tuple): The color can be represented as
            PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
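To illustrate the signature change above, a minimal sketch (not part of this diff) that calls draw_keypoints with the new list-based connectivity argument, in the spirit of the updated tests:

import torch
from torchvision.utils import draw_keypoints

img = torch.full((3, 100, 100), 0, dtype=torch.uint8)     # black uint8 canvas of shape (3, H, W)
keypoints = torch.tensor([[[20.0, 30.0], [60.0, 70.0]]])  # one instance with K=2 keypoints in [x, y]
result = draw_keypoints(img, keypoints, colors="red", connectivity=[(0, 1)], radius=2, width=3)
print(result.shape)  # torch.Size([3, 100, 100]); the two keypoints and the line joining them are drawn in red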
