Make R-CNN models support Automatic Mixed Precision (AMP) #2222

@Okery

🚀 Feature

Now that PyTorch 1.6.0 provides torch.cuda.amp.autocast, I think we can make the R-CNN models support Automatic Mixed Precision (AMP).
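
For context, a mixed-precision training step with torch.cuda.amp typically looks like the sketch below. It is illustrative only: the tiny nn.Sequential model, the MSE loss, and the train_step helper are placeholder assumptions, not the R-CNN training code. AMP is switched on only when a CUDA device is present, so the sketch also runs on CPU. Newer PyTorch releases may prefer the torch.amp namespace for the same functionality.

```python
import torch
from torch import nn

# Enable AMP only when a CUDA device is available; on CPU this runs in float32.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_amp = device.type == "cuda"

# Placeholder model and optimizer (assumptions, not a detection model).
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1)).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# GradScaler scales the loss so small float16 gradients do not underflow.
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)


def train_step(inputs, targets):
    """Run one optimization step, using autocast for the forward pass."""
    optimizer.zero_grad()
    with torch.cuda.amp.autocast(enabled=use_amp):
        outputs = model(inputs)
        loss = nn.functional.mse_loss(outputs.float(), targets)
    # Backward on the scaled loss; step() unscales before the optimizer update.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    return loss.item()


loss_value = train_step(
    torch.randn(4, 8, device=device),
    torch.randn(4, 1, device=device),
)
```

Only the forward pass and loss computation sit inside the autocast context; the backward pass and optimizer step run outside it.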

Motivation

When AMP is enabled, training speed may increase by about 20% on GPUs with fast FP16 support.

Alternatives

There are 2 modifications:

  • In torchvision/ops/roi_align.py, function roi_align:
    the dtype of rois should match the dtype of input, so I replaced
    check_roi_boxes_shape(boxes)
    rois = boxes
    output_size = _pair(output_size)
    if not isinstance(rois, torch.Tensor):
        rois = convert_boxes_to_roi_format(rois)
    return torch.ops.torchvision.roi_align(input, rois, spatial_scale,
                                           output_size[0], output_size[1],
                                           sampling_ratio, aligned)

with

    check_roi_boxes_shape(boxes)
    rois = boxes
    output_size = _pair(output_size)
    if not isinstance(rois, torch.Tensor):
        rois = convert_boxes_to_roi_format(rois)
    rois = rois.to(input.dtype)  # match the feature dtype (e.g. float16 under autocast)
    return torch.ops.torchvision.roi_align(input, rois, spatial_scale,
                                           output_size[0], output_size[1],
                                           sampling_ratio, aligned)
  • In torchvision/models/detection/_utils.py, function encode_boxes:
    the torch.jit.script decorator confuses me. When I printed the dtype of proposals, the output was 6 rather than torch.float32 or torch.float16, so I removed the decorator.
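
The dtype mismatch behind the first modification can be reproduced without a GPU. In the sketch below, a float16 tensor stands in for a feature map produced under autocast, and match_roi_dtype is a hypothetical helper that mirrors the proposed rois.to(input.dtype) line. It does not call torchvision, so it only illustrates the cast, not the operator itself.

```python
import torch

# Stand-in for a feature map produced under autocast (float16 on a GPU).
features = torch.zeros(1, 3, 8, 8, dtype=torch.float16)

# Boxes typically stay in float32. Each row is (batch_index, x1, y1, x2, y2).
rois = torch.tensor([[0.0, 1.0, 1.0, 5.0, 5.0]], dtype=torch.float32)


def match_roi_dtype(feature_map, boxes):
    """Hypothetical helper: cast the boxes to the feature map's dtype."""
    return boxes.to(feature_map.dtype)


dtype_before = rois.dtype
rois = match_roi_dtype(features, rois)
dtype_after = rois.dtype
```

On the second modification: as far as I can tell, TorchScript represents dtypes as integers, and 6 corresponds to torch.float32 in PyTorch's scalar type enumeration. If so, the printed 6 may simply mean the proposals were float32 rather than indicate a bug in the decorator.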

Additional context

I tested the speed of maskrcnn_resnet50_fpn with and without autocast().
Dataset: VOC 2012 Segmentation, train 1463 images, val 1444 images.

GPU       Train/test FPS without AMP   Train/test FPS with AMP   Increase
2080 Ti   7.4/13.3                     9.3/15.2                  25.2%/14.6%
1080 Ti   5.1/8.8                      4.2/7.3                   --
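
The relative change can be recomputed from the FPS figures as a quick check. The sketch below uses only the rounded values from the table, and percent_increase is a hypothetical helper. It gives roughly 25.7% and 14.3% for the 2080 Ti, slightly different from the reported 25.2% and 14.6%, presumably because those were derived from unrounded measurements. It also gives a slowdown of about 17% for both phases on the 1080 Ti.

```python
def percent_increase(before, after):
    """Relative change from `before` to `after`, in percent."""
    return (after / before - 1.0) * 100.0


# (FPS without AMP, FPS with AMP), copied from the table above.
fps = {
    "2080 Ti": {"train": (7.4, 9.3), "test": (13.3, 15.2)},
    "1080 Ti": {"train": (5.1, 4.2), "test": (8.8, 7.3)},
}

# Percentage change per GPU and phase; negative values mean AMP was slower.
changes = {
    gpu: {phase: percent_increase(*pair) for phase, pair in phases.items()}
    for gpu, phases in fps.items()
}

for gpu, phases in changes.items():
    print(gpu, {phase: f"{value:+.1f}%" for phase, value in phases.items()})
```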
