@@ -103,15 +103,21 @@ template <typename T, bool traditional, bool forward, int N = 4>
103103__device__ void rope_impl (
104104 const T* in,
105105 T* out,
106- int offset,
106+ const int * offset,
107107 float inv_freq,
108108 float scale,
109109 const cuda::std::array<int64_t , 3 > strides,
110110 const cuda::std::array<int64_t , 3 > out_strides,
111- int64_t n_batch,
111+ int64_t offset_stride,
112+ int n_head,
112113 uint3 pos,
113114 uint3 dims) {
114- float L = scale * static_cast <float >(pos.y + offset);
115+ auto n_head_up = N * ((n_head + N - 1 ) / N);
116+ auto head_idx = static_cast <int >((pos.z * N) % n_head_up);
117+ auto batch_idx = (pos.z * N) / n_head_up;
118+ auto batch_offset = offset[batch_idx * offset_stride];
119+ float L = scale * static_cast <float >(pos.y + batch_offset);
120+ auto mat_idx = batch_idx * n_head + head_idx;
115121
116122 // Compute costheta, sintheta
117123 float theta = L * inv_freq;
@@ -123,20 +129,19 @@ __device__ void rope_impl(
123129 size_t out_index_1, out_index_2;
124130 if (traditional) {
125131 out_index_1 = 2 * pos.x * out_strides[2 ] + pos.y * out_strides[1 ] +
126- N * pos. z * out_strides[0 ];
132+ mat_idx * out_strides[0 ];
127133 out_index_2 = out_index_1 + 1 ;
128134 in_index_1 =
129- 2 * pos.x * strides[2 ] + pos.y * strides[1 ] + N * pos. z * strides[0 ];
135+ 2 * pos.x * strides[2 ] + pos.y * strides[1 ] + mat_idx * strides[0 ];
130136 in_index_2 = in_index_1 + strides[2 ];
131137 } else {
132138 out_index_1 = pos.x * out_strides[2 ] + pos.y * out_strides[1 ] +
133- N * pos. z * out_strides[0 ];
139+ mat_idx * out_strides[0 ];
134140 out_index_2 = out_index_1 + dims.x * out_strides[2 ];
135- in_index_1 =
136- pos.x * strides[2 ] + pos.y * strides[1 ] + N * pos.z * strides[0 ];
141+ in_index_1 = pos.x * strides[2 ] + pos.y * strides[1 ] + mat_idx * strides[0 ];
137142 in_index_2 = in_index_1 + dims.x * strides[2 ];
138143 }
139- for (int i = 0 ; i < N && pos. z * N + i < n_batch ; ++i) {
144+ for (int i = 0 ; i < N && head_idx + i < n_head ; ++i) {
140145 // Read and write the output
141146 float x1 = static_cast <float >(in[in_index_1]);
142147 float x2 = static_cast <float >(in[in_index_2]);
@@ -167,7 +172,8 @@ __global__ void rope(
167172 float base,
168173 const __grid_constant__ cuda::std::array<int64_t , 3 > strides,
169174 const __grid_constant__ cuda::std::array<int64_t , 3 > out_strides,
170- int64_t n_batch,
175+ int64_t offset_stride,
176+ int n_head,
171177 uint3 dims) {
172178 uint3 pos = make_uint3 (
173179 blockIdx .x * blockDim .x + threadIdx .x ,
@@ -182,12 +188,13 @@ __global__ void rope(
182188 rope_impl<T, traditional, forward>(
183189 in,
184190 out,
185- * offset,
191+ offset,
186192 inv_freq,
187193 scale,
188194 strides,
189195 out_strides,
190- n_batch,
196+ offset_stride,
197+ n_head,
191198 pos,
192199 dims);
193200}
@@ -202,7 +209,8 @@ __global__ void rope_freqs(
202209 float base,
203210 const __grid_constant__ cuda::std::array<int64_t , 3 > strides,
204211 const __grid_constant__ cuda::std::array<int64_t , 3 > out_strides,
205- int64_t n_batch,
212+ int64_t offset_stride,
213+ int n_head,
206214 uint3 dims,
207215 int64_t freq_stride) {
208216 uint3 pos = make_uint3 (
@@ -217,12 +225,13 @@ __global__ void rope_freqs(
217225 rope_impl<T, traditional, forward>(
218226 in,
219227 out,
220- * offset,
228+ offset,
221229 inv_freq,
222230 scale,
223231 strides,
224232 out_strides,
225- n_batch,
233+ offset_stride,
234+ n_head,
226235 pos,
227236 dims);
228237}
@@ -245,23 +254,28 @@ void RoPE::eval_gpu(
245254 auto & offset = inputs[1 ];
246255 auto & out = outputs[0 ];
247256
248- if (in.ndim () < 3 ) {
249- throw std::runtime_error (" [RoPE] Input must have at least 3 dimensions" );
250- }
251-
252257 cuda::std::array<int64_t , 3 > strides;
253258 cuda::std::array<int64_t , 3 > out_strides;
254259 bool donated = false ;
255260 int ndim = in.ndim ();
256- int dispatch_ndim = in.ndim ();
261+
262+ int B = in.shape (0 );
263+ int T = in.shape (-2 );
264+ int D = in.shape (-1 );
265+ size_t mat_size = T * D;
266+ int dispatch_ndim = ndim;
257267 while (in.shape (-dispatch_ndim) == 1 && dispatch_ndim > 3 ) {
258268 dispatch_ndim--;
259269 }
260- size_t mat_size = in.shape (-2 ) * in.shape (-1 );
270+
271+ int N = 1 ;
272+ for (int i = 1 ; i < (ndim - 2 ); ++i) {
273+ N *= in.shape (i);
274+ }
261275
262276 // We apply rope to less that the whole vector so copy to output and then
263277 // apply in-place.
264- if (dims_ < in. shape (- 1 ) ) {
278+ if (dims_ < D ) {
265279 donated = true ;
266280 auto ctype =
267281 (in.flags ().row_contiguous ) ? CopyType::Vector : CopyType::General;
@@ -302,7 +316,7 @@ void RoPE::eval_gpu(
302316 out_strides[2 ] = out.strides ()[ndim - 1 ];
303317
304318 // Some flags to help us dispatch below
305- bool single = in.flags ().row_contiguous && (mat_size == in. shape (- 1 )) ;
319+ bool single = in.flags ().row_contiguous && B == 1 && T == 1 ;
306320 bool with_freqs = inputs.size () == 3 ;
307321
308322 auto & encoder = cu::get_command_encoder (s);
@@ -319,7 +333,7 @@ void RoPE::eval_gpu(
319333 if (single && !with_freqs) {
320334 auto kernel =
321335 cu::rope_single<DataType, traditional.value , forward.value >;
322- uint2 dims = make_uint2 (dims_ / 2 , in. size () / mat_size );
336+ uint2 dims = make_uint2 (dims_ / 2 , N );
323337 auto [grid, block] = get_grid_and_block (dims.x , dims.y , 1 );
324338 encoder.add_kernel_node (
325339 kernel,
@@ -336,7 +350,7 @@ void RoPE::eval_gpu(
336350 } else if (single) {
337351 auto kernel =
338352 cu::rope_single_freqs<DataType, traditional.value , forward.value >;
339- uint2 dims = make_uint2 (dims_ / 2 , in. size () / mat_size );
353+ uint2 dims = make_uint2 (dims_ / 2 , N );
340354 auto [grid, block] = get_grid_and_block (dims.x , dims.y , 1 );
341355 encoder.add_kernel_node (
342356 kernel,
@@ -354,10 +368,14 @@ void RoPE::eval_gpu(
354368 } else if (with_freqs) {
355369 auto kernel =
356370 cu::rope_freqs<DataType, traditional.value , forward.value >;
357- uint3 dims =
358- make_uint3 (dims_ / 2 , in. shape (- 2 ), in. size ( ) / mat_size );
359- dims. z = (dims. z + 3 ) / 4 ;
371+ int n_per_thread = 4 ;
372+ uint32_t dimz = B * ((N + n_per_thread - 1 ) / n_per_thread );
373+ uint3 dims = make_uint3 (dims_ / 2 , T, dimz) ;
360374 auto [grid, block] = get_grid_and_block (dims.x , dims.y , dims.z );
375+ int64_t offset_stride = 0 ;
376+ if (inputs[1 ].ndim () > 0 ) {
377+ offset_stride = inputs[1 ].strides ()[0 ];
378+ }
361379 encoder.add_kernel_node (
362380 kernel,
363381 grid,
@@ -371,15 +389,20 @@ void RoPE::eval_gpu(
371389 std::log2 (base_),
372390 strides,
373391 out_strides,
374- in.size () / mat_size,
392+ offset_stride,
393+ N,
375394 dims,
376395 inputs[2 ].strides (0 ));
377396 } else {
378397 auto kernel = cu::rope<DataType, traditional.value , forward.value >;
379- uint3 dims =
380- make_uint3 (dims_ / 2 , in. shape (- 2 ), in. size ( ) / mat_size );
381- dims. z = (dims. z + 3 ) / 4 ;
398+ int n_per_thread = 4 ;
399+ uint32_t dimz = B * ((N + n_per_thread - 1 ) / n_per_thread );
400+ uint3 dims = make_uint3 (dims_ / 2 , T, dimz) ;
382401 auto [grid, block] = get_grid_and_block (dims.x , dims.y , dims.z );
402+ int64_t offset_stride = 0 ;
403+ if (inputs[1 ].ndim () > 0 ) {
404+ offset_stride = inputs[1 ].strides ()[0 ];
405+ }
383406 encoder.add_kernel_node (
384407 kernel,
385408 grid,
@@ -392,7 +415,8 @@ void RoPE::eval_gpu(
392415 std::log2 (base_),
393416 strides,
394417 out_strides,
395- in.size () / mat_size,
418+ offset_stride,
419+ N,
396420 dims);
397421 }
398422 });
0 commit comments