Skip to content

Commit d99e4d5

Browse files
committed
support for Half types in ROIAlign
1 parent 110d998 commit d99e4d5

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

torchvision/csrc/cpu/ROIAlign_cpu.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,7 @@ at::Tensor ROIAlign_forward_cpu(const at::Tensor& input,
369369
if (output.numel() == 0)
370370
return output;
371371

372-
AT_DISPATCH_FLOATING_TYPES(input.type(), "ROIAlign_forward", [&] {
372+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "ROIAlign_forward", [&] {
373373
ROIAlignForward<scalar_t>(
374374
output_size,
375375
input.data<scalar_t>(),
@@ -414,7 +414,7 @@ at::Tensor ROIAlign_backward_cpu(const at::Tensor& grad,
414414
int h_stride = grad.stride(2);
415415
int w_stride = grad.stride(3);
416416

417-
AT_DISPATCH_FLOATING_TYPES(grad.type(), "ROIAlign_forward", [&] {
417+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad.type(), "ROIAlign_forward", [&] {
418418
ROIAlignBackward<scalar_t>(
419419
grad.numel(),
420420
grad.data<scalar_t>(),

torchvision/csrc/cuda/ROIAlign_cuda.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ at::Tensor ROIAlign_forward_cuda(const at::Tensor& input,
273273
return output;
274274
}
275275

276-
AT_DISPATCH_FLOATING_TYPES(input.type(), "ROIAlign_forward", [&] {
276+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "ROIAlign_forward", [&] {
277277
RoIAlignForward<scalar_t><<<grid, block, 0, stream>>>(
278278
output_size,
279279
input.data<scalar_t>(),
@@ -323,7 +323,7 @@ at::Tensor ROIAlign_backward_cuda(const at::Tensor& grad,
323323
int h_stride = grad.stride(2);
324324
int w_stride = grad.stride(3);
325325

326-
AT_DISPATCH_FLOATING_TYPES(grad.type(), "ROIAlign_backward", [&] {
326+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad.type(), "ROIAlign_backward", [&] {
327327
RoIAlignBackward<scalar_t><<<grid, block, 0, stream>>>(
328328
grad.numel(),
329329
grad.data<scalar_t>(),

0 commit comments

Comments
 (0)