Commit 6f32285

ps_roi_align bw (failed prec)
1 parent 195d03a commit 6f32285

2 files changed: +159 -26 lines changed


torchvision/csrc/ops/mps/ps_roi_align_kernel.mm

Lines changed: 27 additions & 26 deletions
@@ -111,29 +111,30 @@
 at::Tensor ps_roi_align_backward_kernel(
     const at::Tensor& grad,
     const at::Tensor& rois,
+    const at::Tensor& channel_mapping,
     double spatial_scale,
     int64_t pooled_height,
     int64_t pooled_width,
+    int64_t sampling_ratio,
     int64_t batch_size,
     int64_t channels,
     int64_t height,
-    int64_t width,
-    int64_t sampling_ratio,
-    bool aligned) {
+    int64_t width) {
 
   using namespace at::native::mps;
   TORCH_CHECK(grad.is_mps(), "grad must be a MPS tensor");
   TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor");
+  TORCH_CHECK(channel_mapping.is_mps(), "channel_mapping must be a MPS tensor");
 
-  at::TensorArg grad_t{grad, "input", 1}, rois_t{rois, "rois", 2};
+  at::TensorArg grad_t{grad, "input", 1}, rois_t{rois, "rois", 2}, channel_mapping_t{channel_mapping, "channel_mapping", 3};
 
   at::CheckedFrom c = "ps_roi_align_backward_kernel";
-  at::checkAllSameGPU(c, {grad_t, rois_t});
+  at::checkAllSameGPU(c, {grad_t, rois_t, channel_mapping_t});
   at::checkAllSameType(c, {grad_t, rois_t});
 
   float spatial_scale_f = static_cast<float>(spatial_scale);
 
-  at::Tensor grad_input = at::zeros(
+  auto grad_input = at::zeros(
       {batch_size, channels, height, width}, grad.options());
 
   if (grad.numel() == 0) {
@@ -146,11 +147,14 @@
   int64_t w_stride = grad.stride(3);
   int64_t output_size = grad.numel();
 
+  int64_t channels_out = channels / (pooled_height * pooled_width);
+
   at::globalContext().alertNotDeterministic("ps_roi_align_backward_kernel");
-  auto rois_ = rois.contiguous();
+  auto grad_ = grad.contiguous(), rois_ = rois.contiguous();
 
-  id<MTLBuffer> inputBuffer = getMTLBufferStorage(grad);
+  id<MTLBuffer> inputBuffer = getMTLBufferStorage(grad_);
   id<MTLBuffer> roisBuffer = getMTLBufferStorage(rois_);
+  id<MTLBuffer> channelMappingBuffer = getMTLBufferStorage(channel_mapping);
   id<MTLBuffer> outputBuffer = getMTLBufferStorage(grad_input);
   id<MTLDevice> device = MPSDevice::getInstance()->device();
   MPSStream* mpsStream = getCurrentMPSStream();
@@ -167,23 +171,20 @@
 
       [computeEncoder setComputePipelineState:binaryPSO];
       // [N, C, H, W]
-      [computeEncoder setBuffer:inputBuffer offset:grad.storage_offset() * grad.element_size() atIndex:0];
+      [computeEncoder setBuffer:inputBuffer offset:grad_.storage_offset() * grad_.element_size() atIndex:0];
       [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1];
-      [computeEncoder setBuffer:outputBuffer offset:grad_input.storage_offset() * grad_input.element_size() atIndex:2];
+      [computeEncoder setBuffer:channelMappingBuffer offset:channel_mapping.storage_offset() * channel_mapping.element_size() atIndex:2];
+      [computeEncoder setBuffer:outputBuffer offset:grad_input.storage_offset() * grad_input.element_size() atIndex:3];
 
-      [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:3];
-      [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:4];
-      [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:5];
-      [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:6];
-      [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:7];
-      [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:8];
-      [computeEncoder setBytes:&sampling_ratio length:sizeof(int64_t) atIndex:9];
-      [computeEncoder setBytes:&aligned length:sizeof(bool) atIndex:10];
-      [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:11];
-      [computeEncoder setBytes:&n_stride length:sizeof(int64_t) atIndex:12];
-      [computeEncoder setBytes:&c_stride length:sizeof(int64_t) atIndex:13];
-      [computeEncoder setBytes:&h_stride length:sizeof(int64_t) atIndex:14];
-      [computeEncoder setBytes:&w_stride length:sizeof(int64_t) atIndex:15];
+      [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4];
+      [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5];
+      [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6];
+      [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7];
+      [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:8];
+      [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9];
+      [computeEncoder setBytes:&sampling_ratio length:sizeof(int64_t) atIndex:10];
+      [computeEncoder setBytes:&channels_out length:sizeof(int64_t) atIndex:11];
+      [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:12];
 
       // A threadGroup is equivalent to a cuda's block.
       NSUInteger tgSize = binaryPSO.maxTotalThreadsPerThreadgroup;
@@ -206,9 +207,9 @@
   m.impl(
       TORCH_SELECTIVE_NAME("torchvision::ps_roi_align"),
      TORCH_FN(ps_roi_align_forward_kernel));
-  //m.impl(
-  //    TORCH_SELECTIVE_NAME("torchvision::_ps_roi_align_backward"),
-  //    TORCH_FN(ps_roi_align_backward_kernel));
+  m.impl(
+      TORCH_SELECTIVE_NAME("torchvision::_ps_roi_align_backward"),
+      TORCH_FN(ps_roi_align_backward_kernel));
 }
 
 } // namespace ops
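
For context on the host-side changes above: channels_out is the per-position output channel count of position-sensitive RoI align, which splits the C input channels into one group per pooled cell, so channels must be divisible by pooled_height * pooled_width. A minimal sketch of that arithmetic, with hypothetical shape values not taken from this commit:

    // Hypothetical PS RoI Align shapes, for illustration only.
    int64_t pooled_height = 7, pooled_width = 7;
    int64_t channels = 392; // input channels; must divide evenly by 7 * 7
    int64_t channels_out = channels / (pooled_height * pooled_width); // 392 / 49 = 8

Note also that each atIndex above must line up one-to-one with the [[buffer(n)]] bindings of the Metal kernel added in vision_kernels.h below; inserting channel_mapping at index 2 is what shifts every later argument up by one.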

torchvision/csrc/ops/mps/vision_kernels.h

Lines changed: 132 additions & 0 deletions
@@ -748,6 +748,138 @@ kernel void ps_roi_align<DTYPE>( \
 REGISTER_PS_ROI_ALIGN_OP(float);
 REGISTER_PS_ROI_ALIGN_OP(half);
 
+template<typename T>
+kernel void ps_roi_align_backward(
+    constant T * grad_output [[buffer(0)]],
+    constant T * rois [[buffer(1)]],
+    constant int64_t * channel_mapping [[buffer(2)]],
+    device T * grad_input [[buffer(3)]],
+    constant int64_t & output_size [[buffer(4)]],
+    constant int64_t & channels [[buffer(5)]],
+    constant int64_t & height [[buffer(6)]],
+    constant int64_t & width [[buffer(7)]],
+    constant int64_t & pooled_height [[buffer(8)]],
+    constant int64_t & pooled_width [[buffer(9)]],
+    constant int64_t & sampling_ratio [[buffer(10)]],
+    constant int64_t & channels_out [[buffer(11)]],
+    constant float & spatial_scale [[buffer(12)]],
+    uint2 tgid [[threadgroup_position_in_grid]],
+    uint2 tptg [[threads_per_threadgroup]],
+    uint2 tid2 [[thread_position_in_threadgroup]]){
+
+  MPS_1D_KERNEL_LOOP(index, output_size, 1) {
+    // (n, *, ph, pw) is an element in the pooled output
+    int pw = index % pooled_width;
+    int ph = (index / pooled_width) % pooled_height;
+    int n = index / pooled_width / pooled_height / channels_out;
+
+    constant T* offset_rois = rois + n * 5;
+    int roi_batch_ind = offset_rois[0];
+
+    // Do not using rounding; this implementation detail is critical
+    T roi_start_w = offset_rois[1] * spatial_scale - static_cast<T>(0.5);
+    T roi_start_h = offset_rois[2] * spatial_scale - static_cast<T>(0.5);
+    T roi_end_w = offset_rois[3] * spatial_scale - static_cast<T>(0.5);
+    T roi_end_h = offset_rois[4] * spatial_scale - static_cast<T>(0.5);
+
+    // Force too small ROIs to be 1x1
+    T roi_width = roi_end_w - roi_start_w;
+    T roi_height = roi_end_h - roi_start_h;
+    T bin_size_h = roi_height / static_cast<T>(pooled_height);
+    T bin_size_w = roi_width / static_cast<T>(pooled_width);
+
+    int c_in = channel_mapping[index];
+
+    // Do not using floor/ceil; this implementation detail is critical
+    T hstart = static_cast<T>(ph) * bin_size_h + roi_start_h;
+    T wstart = static_cast<T>(pw) * bin_size_w + roi_start_w;
+
+    const T grad_output_this_bin = grad_output[index];
+
+    // We use roi_bin_grid to sample the grid and mimic integral
+    int roi_bin_grid_h = (sampling_ratio > 0)
+        ? sampling_ratio
+        : ceil(roi_height / pooled_height); // e.g., = 2
+    int roi_bin_grid_w =
+        (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
+    const T count = roi_bin_grid_h * roi_bin_grid_w;
+
+    const int offset = (roi_batch_ind * channels + c_in) * height * width;
+
+    for (int iy = 0; iy < roi_bin_grid_h; iy++) {
+      const T y = hstart +
+          static_cast<T>(iy + .5f) * bin_size_h /
+              static_cast<T>(roi_bin_grid_h);
+      for (int ix = 0; ix < roi_bin_grid_w; ix++) {
+        const T x = wstart +
+            static_cast<T>(ix + .5f) * bin_size_w /
+                static_cast<T>(roi_bin_grid_w);
+
+        T w1, w2, w3, w4;
+        int x_low, x_high, y_low, y_high;
+
+        bilinear_interpolate_gradient(
+            height,
+            width,
+            y,
+            x,
+            w1,
+            w2,
+            w3,
+            w4,
+            x_low,
+            x_high,
+            y_low,
+            y_high,
+            index);
+
+        T g1 = grad_output_this_bin * w1 / count;
+        T g2 = grad_output_this_bin * w2 / count;
+        T g3 = grad_output_this_bin * w3 / count;
+        T g4 = grad_output_this_bin * w4 / count;
+
+        if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
+          device atomic_uint* xAtomic = (device atomic_uint*)(grad_input + offset + y_low * width + x_low);
+          device atomic_uint* yAtomic = (device atomic_uint*)(grad_input + offset + y_low * width + x_high);
+          device atomic_uint* zAtomic = (device atomic_uint*)(grad_input + offset + y_high * width + x_low);
+          device atomic_uint* wAtomic = (device atomic_uint*)(grad_input + offset + y_high * width + x_high);
+
+          // atomic_float data type is supported on Metal 3 onward.
+          // TODO: Use native atomic_fetch_add_explicit for Metal 3.
+          atomic_add_float(xAtomic, static_cast<T>(g1));
+          atomic_add_float(yAtomic, static_cast<T>(g2));
+          atomic_add_float(zAtomic, static_cast<T>(g3));
+          atomic_add_float(wAtomic, static_cast<T>(g4));
+        } // if
+      } // ix
+    } // iy
+  }
+}
+
+#define REGISTER_PS_ROI_ALIGN_BACKWARD_OP(DTYPE) \
+template \
+[[host_name("ps_roi_align_backward_" #DTYPE)]] \
+kernel void ps_roi_align_backward<DTYPE>( \
+    constant DTYPE * grad_output [[buffer(0)]], \
+    constant DTYPE * rois [[buffer(1)]], \
+    constant int64_t * channel_mapping [[buffer(2)]], \
+    device DTYPE * grad_input [[buffer(3)]], \
+    constant int64_t & output_size [[buffer(4)]], \
+    constant int64_t & channels [[buffer(5)]], \
+    constant int64_t & height [[buffer(6)]], \
+    constant int64_t & width [[buffer(7)]], \
+    constant int64_t & pooled_height [[buffer(8)]], \
+    constant int64_t & pooled_width [[buffer(9)]], \
+    constant int64_t & sampling_ratio [[buffer(10)]], \
+    constant int64_t & channels_out [[buffer(11)]], \
+    constant float & spatial_scale [[buffer(12)]], \
+    uint2 tgid [[threadgroup_position_in_grid]], \
+    uint2 tptg [[threads_per_threadgroup]], \
+    uint2 tid2 [[thread_position_in_threadgroup]]);
+
+REGISTER_PS_ROI_ALIGN_BACKWARD_OP(float);
+REGISTER_PS_ROI_ALIGN_BACKWARD_OP(half);
+
 )VISION_METAL";
 
 static id<MTLLibrary> compileBinaryOpsLibrary(id<MTLDevice> device) {
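
The backward kernel accumulates into grad_input through atomic_add_float, a helper assumed to be defined earlier in this header; as the in-kernel comment notes, a native atomic_float type is only available from Metal 3 onward. A minimal sketch of such a compare-and-swap emulation, shown as an illustration rather than the header's actual definition:

    // Sketch: emulate atomic float addition with a 32-bit CAS loop.
    // Assumes the destination holds a 32-bit float; a half-precision
    // grad_input is not addressable this way, which may be related to
    // the "(failed prec)" note in the commit message.
    inline void atomic_add_float(device atomic_uint* addr, float value) {
      uint expected = atomic_load_explicit(addr, memory_order_relaxed);
      uint desired = as_type<uint>(as_type<float>(expected) + value);
      // On failure, expected is refreshed with the current contents,
      // so the sum is recomputed and the exchange retried.
      while (!atomic_compare_exchange_weak_explicit(
                 addr, &expected, desired,
                 memory_order_relaxed, memory_order_relaxed)) {
        desired = as_type<uint>(as_type<float>(expected) + value);
      }
    }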

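bilinear_interpolate_gradient is likewise assumed to be defined earlier in the header. For reference, a sketch of the standard RoI-align weight computation that the call site implies (the names mirror the kernel above; this is not necessarily the header's exact code): the four weights are the bilinear coefficients of the sample point (y, x) with respect to its four surrounding integer grid corners.

    // Sketch of the standard bilinear gradient helper used by RoI align.
    template <typename T>
    void bilinear_interpolate_gradient(
        int height, int width, T y, T x,
        thread T& w1, thread T& w2, thread T& w3, thread T& w4,
        thread int& x_low, thread int& x_high,
        thread int& y_low, thread int& y_high,
        int index /* unused here; passed through from the caller */) {
      // Sample point outside the feature map: zero weights, sentinel indices.
      if (y < -1.0f || y > height || x < -1.0f || x > width) {
        w1 = w2 = w3 = w4 = 0.0f;
        x_low = x_high = y_low = y_high = -1;
        return;
      }
      if (y <= 0) y = 0;
      if (x <= 0) x = 0;
      y_low = (int)y;
      x_low = (int)x;
      if (y_low >= height - 1) {
        y_high = y_low = height - 1;
        y = (T)y_low;
      } else {
        y_high = y_low + 1;
      }
      if (x_low >= width - 1) {
        x_high = x_low = width - 1;
        x = (T)x_low;
      } else {
        x_high = x_low + 1;
      }
      // Distances from the low corner give the four bilinear weights.
      T ly = y - y_low, lx = x - x_low;
      T hy = 1.0f - ly, hx = 1.0f - lx;
      w1 = hy * hx; w2 = hy * lx; w3 = ly * hx; w4 = ly * lx;
    }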