Update docstring of the heatmaps_to_keypoints #8028

Open: wants to merge 1 commit into base: main
44 changes: 41 additions & 3 deletions torchvision/models/detection/roi_heads.py
@@ -232,9 +232,47 @@ def _onnx_heatmaps_to_keypoints_loop(


def heatmaps_to_keypoints(maps, rois):
"""Extract predicted keypoint locations from heatmaps. Output has shape
(#rois, 4, #keypoints) with the 4 rows corresponding to (x, y, logit, prob)
for each keypoint.
"""Extract predicted keypoint locations from heatmaps.

Given per-ROI keypoint heatmaps and the matching regions of interest (ROIs), this function
locates each keypoint and maps it to image coordinates. It returns two tensors:

- The first tensor has the shape (#rois, #keypoints, 3), where each row corresponds
to a keypoint and contains the following information:
- x: The x-coordinate of the keypoint location in the image.
- y: The y-coordinate of the keypoint location in the image.
- vis: The visibility flag of the keypoint (this function sets it to 1).

- The second tensor has the shape (#rois, #keypoints), where each value is the logit
of a keypoint: the raw, unnormalized heatmap score at the predicted location, which
reflects the model's confidence in that prediction.

Args:
maps (torch.Tensor): A tensor containing heatmaps for keypoints. It has the shape
(#rois, #keypoints, height, width).
rois (torch.Tensor): A tensor containing regions of interest (ROIs) for which
keypoints are predicted. Each row in `rois` represents an ROI and should
have the format (x1, y1, x2, y2), where (x1, y1) are the coordinates of the
top-left corner and (x2, y2) are the coordinates of the bottom-right corner.

Returns:
Tuple[torch.Tensor, torch.Tensor]: A tuple containing two tensors:
- xy_preds (torch.Tensor): The predicted keypoint locations with shape
(#rois, #keypoints, 3), where the last dimension corresponds to (x, y, vis).
- end_scores (torch.Tensor): The confidence scores (logits) for each predicted
keypoint with shape (#rois, #keypoints).

The `xy_preds` tensor contains the estimated keypoint locations in the image coordinate
system, while the `end_scores` tensor provides confidence scores for each keypoint.
You can use the `xy_preds` to obtain the actual keypoint locations in the image, and
the `end_scores` to assess the model's confidence in its predictions.

Note:
- The `vis` value in `xy_preds` is a placeholder that this function always sets to 1;
it does not indicate whether a keypoint is actually visible or occluded in the
image.
- You can threshold the `end_scores` to filter out keypoints with low confidence.

"""
# This function converts a discrete image coordinate in a HEATMAP_SIZE x
# HEATMAP_SIZE image to a continuous keypoint coordinate. We maintain
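To make the output layout in the proposed docstring concrete, here is a minimal plain-Python sketch of the decoding for a single ROI. It is an illustration under simplifying assumptions, not the torchvision implementation: `decode_keypoints_sketch` is a hypothetical name, heatmaps are nested lists rather than tensors, and the argmax is taken directly on the heatmap grid, whereas the real function works on batched tensors and, as far as I can tell from the source, resizes each heatmap to the ROI size first.

```python
from typing import List, Sequence, Tuple

Heatmap = Sequence[Sequence[float]]  # one keypoint's heatmap, indexed [row][col]


def decode_keypoints_sketch(
    maps: Sequence[Heatmap],
    roi: Tuple[float, float, float, float],
) -> Tuple[List[Tuple[float, float, float]], List[float]]:
    """Hypothetical helper: decode one ROI's heatmaps into (x, y, vis) and logits.

    maps holds one heatmap per keypoint; roi is (x1, y1, x2, y2) in image coordinates.
    """
    x1, y1, x2, y2 = roi
    keypoints: List[Tuple[float, float, float]] = []
    logits: List[float] = []
    for heatmap in maps:
        height, width = len(heatmap), len(heatmap[0])
        # Find the cell with the highest raw score (the logit).
        best_row, best_col = 0, 0
        for row in range(height):
            for col in range(width):
                if heatmap[row][col] > heatmap[best_row][best_col]:
                    best_row, best_col = row, col
        # Map the centre of that cell from heatmap space into the ROI in image space.
        x = x1 + (best_col + 0.5) * (x2 - x1) / width
        y = y1 + (best_row + 0.5) * (y2 - y1) / height
        keypoints.append((x, y, 1.0))  # vis is a constant 1, mirroring the docstring
        logits.append(heatmap[best_row][best_col])
    return keypoints, logits


# Two keypoints, each with a 2x2 heatmap, inside the ROI (10, 20) to (30, 40).
example_maps = [
    [[0.0, 0.0], [0.0, 5.0]],
    [[3.0, 1.0], [0.0, 2.0]],
]
kps, scores = decode_keypoints_sketch(example_maps, (10.0, 20.0, 30.0, 40.0))
print(kps)     # [(25.0, 35.0, 1.0), (15.0, 25.0, 1.0)]
print(scores)  # [5.0, 3.0]
```

The first list corresponds to one row of `xy_preds` per keypoint and the second to `end_scores`; filtering on the logits, as the Note suggests, would be a simple comparison against a threshold chosen by the caller.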