Skip to content

Commit 64735ab

Browse files
committed
added relevant headers for ROIAlign backwards
1 parent 3722e66 commit 64735ab

File tree

2 files changed

+19
-7
lines changed

2 files changed

+19
-7
lines changed

torchvision/csrc/ROIAlign.h

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,13 @@
77
#endif
88

99
// Interface for Python
10-
at::Tensor ROIAlign_forward(const at::Tensor& input,
11-
const at::Tensor& rois,
12-
const float spatial_scale,
13-
const int pooled_height,
14-
const int pooled_width,
15-
const int sampling_ratio) {
10+
at::Tensor ROIAlign_forward(const at::Tensor& input, // Input feature map.
11+
const at::Tensor& rois, // List of ROIs to pool over.
12+
const float spatial_scale, // The scale of the image features. ROIs will be scaled to this.
13+
const int pooled_height, // The height of the pooled feature map.
14+
const int pooled_width, // The width of the pooled feature
15+
const int sampling_ratio) // The number of points to sample in each bin along each axis.
16+
{
1617
if (input.type().is_cuda()) {
1718
#ifdef WITH_CUDA
1819
return ROIAlign_forward_cuda(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio);
@@ -40,6 +41,6 @@ at::Tensor ROIAlign_backward(const at::Tensor& grad,
4041
AT_ERROR("Not compiled with GPU support");
4142
#endif
4243
}
43-
AT_ERROR("Not implemented on the CPU");
44+
return ROIAlign_backward_cpu(grad, rois, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width, sampling_ratio);
4445
}
4546

torchvision/csrc/cpu/vision.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,17 @@ at::Tensor ROIAlign_forward_cpu(const at::Tensor &input,
2525
const int pooled_width,
2626
const int sampling_ratio);
2727

28+
at::Tensor ROIAlign_backward_cpu(const at::Tensor &grad,
29+
const at::Tensor &rois,
30+
const float spatial_scale,
31+
const int pooled_height,
32+
const int pooled_width,
33+
const int batch_size,
34+
const int channels,
35+
const int height,
36+
const int width,
37+
const int sampling_ratio);
38+
2839
at::Tensor nms_cpu(const at::Tensor &dets,
2940
const at::Tensor &scores,
3041
const float threshold);

0 commit comments

Comments
 (0)