|
4 | 4 | =======================
|
5 | 5 |
|
6 | 6 | This example illustrates some of the utilities that torchvision offers for
|
7 |
| -visualizing images, bounding boxes, and segmentation masks. |
| 7 | +visualizing images, bounding boxes, segmentation masks and keypoints. |
8 | 8 | """
|
9 | 9 |
|
10 |
| -# sphinx_gallery_thumbnail_path = "../../gallery/assets/visualization_utils_thumbnail.png" |
| 10 | +# sphinx_gallery_thumbnail_path = "../../gallery/assets/visualization_utils_thumbnail2.png" |
11 | 11 |
|
12 | 12 | import torch
|
13 | 13 | import numpy as np
|
@@ -366,3 +366,110 @@ def show(imgs):
|
366 | 366 | # The two 'people' masks in the first image were not selected because they have
|
367 | 367 | # a lower score than the score threshold. Similarly in the second image, the
|
368 | 368 | # instance with class 15 (which corresponds to 'bench') was not selected.
|
| 369 | + |
| 370 | +##################################### |
| 371 | +# Visualizing keypoints |
| 372 | +# ------------------------------ |
| 373 | +# The :func:`~torchvision.utils.draw_keypoints` function can be used to |
| 374 | +# draw keypoints on images. We will see how to use it with |
| 375 | +# torchvision's KeypointRCNN loaded with :func:`~torchvision.models.detection.keypointrcnn_resnet50_fpn`. |
| 376 | +# We will first have a look at the output of the model. |
| 377 | +# |
| 378 | +# Note that the keypoint detection model does not need normalized images. |
| 379 | +# |
| 380 | + |
| 381 | +from torchvision.models.detection import keypointrcnn_resnet50_fpn |
| 382 | +from torchvision.io import read_image |
| 383 | + |
| 384 | +person_int = read_image(str(Path("assets") / "person1.jpg")) |
| 385 | +person_float = convert_image_dtype(person_int, dtype=torch.float) |
| 386 | + |
| 387 | +model = keypointrcnn_resnet50_fpn(pretrained=True, progress=False) |
| 388 | +model = model.eval() |
| 389 | + |
| 390 | +outputs = model([person_float]) |
| 391 | +print(outputs) |
| 392 | + |
| 393 | +##################################### |
| 394 | +# As we see, the output contains a list of dictionaries. |
| 395 | +# The output list is of length batch_size. |
| 396 | +# We currently have just a single image, so the length of the list is 1. |
| 397 | +# Each entry in the list corresponds to an input image, |
| 398 | +# and it is a dict with keys `boxes`, `labels`, `scores`, `keypoints` and `keypoints_scores`. |
| 399 | +# Each value associated to those keys has `num_instances` elements in it. |
| 400 | +# In our case above there are 2 instances detected in the image. |
| 401 | + |
| 402 | +kpts = outputs[0]['keypoints'] |
| 403 | +scores = outputs[0]['scores'] |
| 404 | + |
| 405 | +print(kpts) |
| 406 | +print(scores) |
| 407 | + |
| 408 | +##################################### |
| 409 | +# The KeypointRCNN model detects there are two instances in the image. |
| 410 | +# If you plot the boxes by using :func:`~torchvision.utils.draw_bounding_boxes` |
| 411 | +# you would recognize they are the person and the surfboard. |
| 412 | +# If we look at the scores, we will realize that the model is much more confident about the person than the surfboard. |
| 413 | +# We could now set a confidence threshold and plot only the instances we are confident enough about. |
| 414 | +# Let us set a threshold of 0.75 and filter out the keypoints corresponding to the person. |
| 415 | + |
| 416 | +detect_threshold = 0.75 |
| 417 | +idx = torch.where(scores > detect_threshold) |
| 418 | +keypoints = kpts[idx] |
| 419 | + |
| 420 | +print(keypoints) |
| 421 | + |
| 422 | +##################################### |
| 423 | +# Great, now we have the keypoints corresponding to the person. |
| 424 | +# Each keypoint is represented by x, y coordinates and the visibility. |
| 425 | +# We can now use the :func:`~torchvision.utils.draw_keypoints` function to draw keypoints. |
| 426 | +# Note that the utility expects uint8 images. |
| 427 | + |
| 428 | +from torchvision.utils import draw_keypoints |
| 429 | + |
| 430 | +res = draw_keypoints(person_int, keypoints, colors="blue", radius=3) |
| 431 | +show(res) |
| 432 | + |
| 433 | +##################################### |
| 434 | +# As we see the keypoints appear as colored circles over the image. |
| 435 | +# The COCO keypoints for a person are ordered and represent the following list. |
| 436 | + |
| 437 | +coco_keypoints = [ |
| 438 | + "nose", "left_eye", "right_eye", "left_ear", "right_ear", |
| 439 | + "left_shoulder", "right_shoulder", "left_elbow", "right_elbow", |
| 440 | + "left_wrist", "right_wrist", "left_hip", "right_hip", |
| 441 | + "left_knee", "right_knee", "left_ankle", "right_ankle", |
| 442 | +] |
| 443 | + |
| 444 | +##################################### |
| 445 | +# What if we are interested in joining the keypoints? |
| 446 | +# This is especially useful for tasks such as pose detection or action recognition. |
| 447 | +# We can join the keypoints easily using the `connectivity` parameter. |
| 448 | +# A close observation would reveal that we would need to join the points in the |
| 449 | +# order below to construct a human skeleton. |
| 450 | +# |
| 451 | +# nose -> left_eye -> left_ear. (0, 1), (1, 3) |
| 452 | +# |
| 453 | +# nose -> right_eye -> right_ear. (0, 2), (2, 4) |
| 454 | +# |
| 455 | +# nose -> left_shoulder -> left_elbow -> left_wrist. (0, 5), (5, 7), (7, 9) |
| 456 | +# |
| 457 | +# nose -> right_shoulder -> right_elbow -> right_wrist. (0, 6), (6, 8), (8, 10) |
| 458 | +# |
| 459 | +# left_shoulder -> left_hip -> left_knee -> left_ankle. (5, 11), (11, 13), (13, 15) |
| 460 | +# |
| 461 | +# right_shoulder -> right_hip -> right_knee -> right_ankle. (6, 12), (12, 14), (14, 16) |
| 462 | +# |
| 463 | +# We will create a list containing these keypoint ids to be connected. |
| 464 | + |
| 465 | +connect_skeleton = [ |
| 466 | + (0, 1), (0, 2), (1, 3), (2, 4), (0, 5), (0, 6), (5, 7), (6, 8), |
| 467 | + (7, 9), (8, 10), (5, 11), (6, 12), (11, 13), (12, 14), (13, 15), (14, 16) |
| 468 | +] |
| 469 | + |
| 470 | +##################################### |
| 471 | +# We pass the above list to the connectivity parameter to connect the keypoints. |
| 472 | +# |
| 473 | + |
| 474 | +res = draw_keypoints(person_int, keypoints, connectivity=connect_skeleton, colors="blue", radius=4, width=3) |
| 475 | +show(res) |
0 commit comments