diff --git a/torchvision/models/detection/roi_heads.py b/torchvision/models/detection/roi_heads.py index 51b210cb6f3..cbca6b8364b 100644 --- a/torchvision/models/detection/roi_heads.py +++ b/torchvision/models/detection/roi_heads.py @@ -232,9 +232,47 @@ def _onnx_heatmaps_to_keypoints_loop( def heatmaps_to_keypoints(maps, rois): - """Extract predicted keypoint locations from heatmaps. Output has shape - (#rois, 4, #keypoints) with the 4 rows corresponding to (x, y, logit, prob) - for each keypoint. + """Extract predicted keypoint locations from heatmaps. + + This function takes as input heatmaps and regions of interest (ROIs) and extracts + predicted keypoint locations. The output consists of two tensors: + + - The first tensor has shape (#rois, #keypoints, 3), where each row corresponds + to a keypoint and contains the following information: + - x: The x-coordinate of the keypoint location in the image. + - y: The y-coordinate of the keypoint location in the image. + - vis: The visibility score of the keypoint (usually 1 for visible keypoints). + + - The second tensor has shape (#rois, #keypoints), where each value corresponds to + the logit score of a keypoint. Logits are raw scores that represent the confidence + of the model in predicting the keypoint's location. + + Args: + maps (torch.Tensor): A tensor containing heatmaps for keypoints. It has the shape + (#rois, #keypoints, height, width). + rois (torch.Tensor): A tensor containing regions of interest (ROIs) for which + keypoints are predicted. Each row in `rois` represents an ROI and should + have the format (x1, y1, x2, y2), where (x1, y1) are the coordinates of the + top-left corner and (x2, y2) are the coordinates of the bottom-right corner. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing two tensors: + - xy_preds (torch.Tensor): The predicted keypoint locations with shape + (#rois, #keypoints, 3), where the last dimension corresponds to (x, y, vis). 
+ - end_scores (torch.Tensor): The confidence scores (logits) for each predicted + keypoint with shape (#rois, #keypoints). + + The `xy_preds` tensor contains the estimated keypoint locations in the image coordinate + system, while the `end_scores` tensor provides confidence scores for each keypoint. + You can use `xy_preds` to obtain the actual keypoint locations in the image, and + `end_scores` to assess the model's confidence in its predictions. + + Note: + - The `vis` value in `xy_preds` typically represents the visibility of the keypoint, + where 1 indicates a visible keypoint and 0 indicates an occluded or invisible + keypoint. + - You can threshold the `end_scores` to filter out keypoints with low confidence. + """ # This function converts a discrete image coordinate in a HEATMAP_SIZE x # HEATMAP_SIZE image to a continuous keypoint coordinate. We maintain