Skip to content

Commit 131ba13

Browse files
authored
Fix wrong clamping in RoIAlign with aligned=True (#2438)
* Fix wrong clamping in RoIAlign with aligned=True * Fix silly mistake * Bugfix pointed out during code-review
1 parent 0a8586c commit 131ba13

File tree

2 files changed

+28
-12
lines changed

2 files changed

+28
-12
lines changed

torchvision/csrc/cpu/ROIAlign_cpu.cpp

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -141,9 +141,13 @@ void ROIAlignForward(
141141
T roi_end_w = offset_rois[3] * spatial_scale - offset;
142142
T roi_end_h = offset_rois[4] * spatial_scale - offset;
143143

144-
// Force malformed ROIs to be 1x1
145-
T roi_width = std::max(roi_end_w - roi_start_w, (T)1.);
146-
T roi_height = std::max(roi_end_h - roi_start_h, (T)1.);
144+
T roi_width = roi_end_w - roi_start_w;
145+
T roi_height = roi_end_h - roi_start_h;
146+
if (!aligned) {
147+
// Force malformed ROIs to be 1x1
148+
roi_width = std::max(roi_width, (T)1.);
149+
roi_height = std::max(roi_height, (T)1.);
150+
}
147151

148152
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
149153
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
@@ -309,9 +313,13 @@ void ROIAlignBackward(
309313
T roi_end_w = offset_rois[3] * spatial_scale - offset;
310314
T roi_end_h = offset_rois[4] * spatial_scale - offset;
311315

312-
// Force malformed ROIs to be 1x1
313-
T roi_width = std::max(roi_end_w - roi_start_w, (T)1.);
314-
T roi_height = std::max(roi_end_h - roi_start_h, (T)1.);
316+
T roi_width = roi_end_w - roi_start_w;
317+
T roi_height = roi_end_h - roi_start_h;
318+
if (!aligned) {
319+
// Force malformed ROIs to be 1x1
320+
roi_width = std::max(roi_width, (T)1.);
321+
roi_height = std::max(roi_height, (T)1.);
322+
}
315323

316324
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
317325
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);

torchvision/csrc/cuda/ROIAlign_cuda.cu

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,13 @@ __global__ void RoIAlignForward(
9191
T roi_end_w = offset_rois[3] * spatial_scale - offset;
9292
T roi_end_h = offset_rois[4] * spatial_scale - offset;
9393

94-
// Force malformed ROIs to be 1x1
95-
T roi_width = max(roi_end_w - roi_start_w, (T)1.);
96-
T roi_height = max(roi_end_h - roi_start_h, (T)1.);
94+
T roi_width = roi_end_w - roi_start_w;
95+
T roi_height = roi_end_h - roi_start_h;
96+
if (!aligned) {
97+
// Force malformed ROIs to be 1x1
98+
roi_width = max(roi_width, (T)1.);
99+
roi_height = max(roi_height, (T)1.);
100+
}
97101

98102
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
99103
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
@@ -229,9 +233,13 @@ __global__ void RoIAlignBackward(
229233
T roi_end_w = offset_rois[3] * spatial_scale - offset;
230234
T roi_end_h = offset_rois[4] * spatial_scale - offset;
231235

232-
// Force malformed ROIs to be 1x1
233-
T roi_width = max(roi_end_w - roi_start_w, (T)1.);
234-
T roi_height = max(roi_end_h - roi_start_h, (T)1.);
236+
T roi_width = roi_end_w - roi_start_w;
237+
T roi_height = roi_end_h - roi_start_h;
238+
if (!aligned) {
239+
// Force malformed ROIs to be 1x1
240+
roi_width = max(roi_width, (T)1.);
241+
roi_height = max(roi_height, (T)1.);
242+
}
235243

236244
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
237245
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);

0 commit comments

Comments
 (0)