@@ -748,6 +748,138 @@ kernel void ps_roi_align<DTYPE>( \
748
748
REGISTER_PS_ROI_ALIGN_OP(float);
749
749
REGISTER_PS_ROI_ALIGN_OP(half);
750
750
751
+ template<typename T> // T is the element type; instantiated below for float and half
752
+ kernel void ps_roi_align_backward(
753
+ constant T * grad_output [[buffer(0)]],
754
+ constant T * rois [[buffer(1)]],
755
+ constant int64_t * channel_mapping [[buffer(2)]],
756
+ device T * grad_input [[buffer(3)]],
757
+ constant int64_t & output_size [[buffer(4)]],
758
+ constant int64_t & channels [[buffer(5)]],
759
+ constant int64_t & height [[buffer(6)]],
760
+ constant int64_t & width [[buffer(7)]],
761
+ constant int64_t & pooled_height [[buffer(8)]],
762
+ constant int64_t & pooled_width [[buffer(9)]],
763
+ constant int64_t & sampling_ratio [[buffer(10)]],
764
+ constant int64_t & channels_out [[buffer(11)]],
765
+ constant float & spatial_scale [[buffer(12)]],
766
+ uint2 tgid [[threadgroup_position_in_grid]],
767
+ uint2 tptg [[threads_per_threadgroup]],
768
+ uint2 tid2 [[thread_position_in_threadgroup]]){
769
+ // Backward kernel for PS-ROI-Align: for each pooled output element `index`, scatters grad_output[index] into grad_input at the four bilinear-neighbour pixels of every sampling point in that bin, weighted by w1..w4 and divided by the number of sampling points.
770
+ MPS_1D_KERNEL_LOOP(index, output_size, 1) { // macro defined elsewhere; presumably iterates index over [0, output_size) using tgid/tptg/tid2 -- TODO confirm
771
+ // (n, *, ph, pw) is an element in the pooled output
772
+ int pw = index % pooled_width; // column of the bin inside the pooled output
773
+ int ph = (index / pooled_width) % pooled_height; // row of the bin inside the pooled output
774
+ int n = index / pooled_width / pooled_height / channels_out; // ROI number; selects the row of `rois`
775
+
776
+ constant T* offset_rois = rois + n * 5; // 5 values per ROI: [0] is read as the batch index, [1..4] as w-start, h-start, w-end, h-end
777
+ int roi_batch_ind = offset_rois[0];
778
+
779
+ // Do not use rounding; this implementation detail is critical. ROI corners are scaled by spatial_scale and then shifted by 0.5.
780
+ T roi_start_w = offset_rois[1] * spatial_scale - static_cast<T>(0.5);
781
+ T roi_start_h = offset_rois[2] * spatial_scale - static_cast<T>(0.5);
782
+ T roi_end_w = offset_rois[3] * spatial_scale - static_cast<T>(0.5);
783
+ T roi_end_h = offset_rois[4] * spatial_scale - static_cast<T>(0.5);
784
+
785
+ // ROI extent and per-bin size. NOTE(review): no minimum-size clamp is applied here, so a degenerate ROI yields zero or negative extents.
786
+ T roi_width = roi_end_w - roi_start_w;
787
+ T roi_height = roi_end_h - roi_start_h;
788
+ T bin_size_h = roi_height / static_cast<T>(pooled_height);
789
+ T bin_size_w = roi_width / static_cast<T>(pooled_width);
790
+
791
+ int c_in = channel_mapping[index]; // input channel associated with this output element (presumably recorded by the forward kernel -- TODO confirm); int64_t narrowed to int
792
+
793
+ // Do not use floor/ceil; this implementation detail is critical. (hstart, wstart) is the top-left corner of bin (ph, pw).
794
+ T hstart = static_cast<T>(ph) * bin_size_h + roi_start_h;
795
+ T wstart = static_cast<T>(pw) * bin_size_w + roi_start_w;
796
+
797
+ const T grad_output_this_bin = grad_output[index]; // upstream gradient for this pooled element
798
+
799
+ // We use roi_bin_grid to sample the grid and mimic integral
800
+ int roi_bin_grid_h = (sampling_ratio > 0)
801
+ ? sampling_ratio
802
+ : ceil(roi_height / pooled_height); // e.g., = 2
803
+ int roi_bin_grid_w =
804
+ (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
805
+ const T count = roi_bin_grid_h * roi_bin_grid_w; // number of sampling points per bin; each contribution below is divided by it
806
+
807
+ const int offset = (roi_batch_ind * channels + c_in) * height * width; // start of the (roi_batch_ind, c_in) plane in grad_input. NOTE(review): the int64_t product is narrowed to int -- confirm it cannot exceed INT_MAX
808
+
809
+ for (int iy = 0; iy < roi_bin_grid_h; iy++) {
810
+ const T y = hstart +
811
+ static_cast<T>(iy + .5f) * bin_size_h /
812
+ static_cast<T>(roi_bin_grid_h); // y of the centre of the iy-th sub-cell of the bin
813
+ for (int ix = 0; ix < roi_bin_grid_w; ix++) {
814
+ const T x = wstart +
815
+ static_cast<T>(ix + .5f) * bin_size_w /
816
+ static_cast<T>(roi_bin_grid_w); // x of the centre of the ix-th sub-cell of the bin
817
+
818
+ T w1, w2, w3, w4; // bilinear weights; presumably filled in by the helper call below
819
+ int x_low, x_high, y_low, y_high; // neighbouring pixel indices; presumably filled in by the helper call below
820
+
821
+ bilinear_interpolate_gradient(
822
+ height,
823
+ width,
824
+ y,
825
+ x,
826
+ w1,
827
+ w2,
828
+ w3,
829
+ w4,
830
+ x_low,
831
+ x_high,
832
+ y_low,
833
+ y_high,
834
+ index);
835
+
836
+ T g1 = grad_output_this_bin * w1 / count; // contribution to (y_low, x_low)
837
+ T g2 = grad_output_this_bin * w2 / count; // contribution to (y_low, x_high)
838
+ T g3 = grad_output_this_bin * w3 / count; // contribution to (y_high, x_low)
839
+ T g4 = grad_output_this_bin * w4 / count; // contribution to (y_high, x_high)
840
+
841
+ if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) { // skip samples with a negative index (presumably how the helper marks out-of-range samples -- TODO confirm)
842
+ device atomic_uint* xAtomic = (device atomic_uint*)(grad_input + offset + y_low * width + x_low);
843
+ device atomic_uint* yAtomic = (device atomic_uint*)(grad_input + offset + y_low * width + x_high);
844
+ device atomic_uint* zAtomic = (device atomic_uint*)(grad_input + offset + y_high * width + x_low);
845
+ device atomic_uint* wAtomic = (device atomic_uint*)(grad_input + offset + y_high * width + x_high);
846
+ // NOTE(review): grad_input is `device T*` but is reinterpreted as atomic_uint; with T = half the element is narrower than the atomic word -- confirm atomic_add_float handles this instantiation correctly.
847
+ // atomic_float data type is supported on Metal 3 onward.
848
+ // TODO: Use native atomic_fetch_add_explicit for Metal 3.
849
+ atomic_add_float(xAtomic, static_cast<T>(g1));
850
+ atomic_add_float(yAtomic, static_cast<T>(g2));
851
+ atomic_add_float(zAtomic, static_cast<T>(g3));
852
+ atomic_add_float(wAtomic, static_cast<T>(g4));
853
+ } // if
854
+ } // ix
855
+ } // iy
856
+ }
857
+ }
858
+ // Explicitly instantiates ps_roi_align_backward<DTYPE> and exposes it under the host name "ps_roi_align_backward_" followed by the stringified DTYPE; the parameter list must repeat the template's signature exactly.
859
+ #define REGISTER_PS_ROI_ALIGN_BACKWARD_OP(DTYPE) \
860
+ template \
861
+ [[host_name("ps_roi_align_backward_" #DTYPE)]] \
862
+ kernel void ps_roi_align_backward<DTYPE>( \
863
+ constant DTYPE * grad_output [[buffer(0)]], \
864
+ constant DTYPE * rois [[buffer(1)]], \
865
+ constant int64_t * channel_mapping [[buffer(2)]], \
866
+ device DTYPE * grad_input [[buffer(3)]], \
867
+ constant int64_t & output_size [[buffer(4)]], \
868
+ constant int64_t & channels [[buffer(5)]], \
869
+ constant int64_t & height [[buffer(6)]], \
870
+ constant int64_t & width [[buffer(7)]], \
871
+ constant int64_t & pooled_height [[buffer(8)]], \
872
+ constant int64_t & pooled_width [[buffer(9)]], \
873
+ constant int64_t & sampling_ratio [[buffer(10)]], \
874
+ constant int64_t & channels_out [[buffer(11)]], \
875
+ constant float & spatial_scale [[buffer(12)]], \
876
+ uint2 tgid [[threadgroup_position_in_grid]], \
877
+ uint2 tptg [[threads_per_threadgroup]], \
878
+ uint2 tid2 [[thread_position_in_threadgroup]]);
879
+
880
+ REGISTER_PS_ROI_ALIGN_BACKWARD_OP(float); // exposes the kernel as "ps_roi_align_backward_float"
880
+ REGISTER_PS_ROI_ALIGN_BACKWARD_OP(half); // exposes the kernel as "ps_roi_align_backward_half"
882
+
751
883
)VISION_METAL" ;
752
884
753
885
static id<MTLLibrary> compileBinaryOpsLibrary (id<MTLDevice> device) {
0 commit comments