Skip to content

Commit 0f5b9a8

Browse files
Awni Hannun and faisalmemon
authored and committed
Add batch offsets for mx.fast.rope (ml-explore#2564)
* implement batch rope for Metal * cuda rope (ml-explore#2576)
1 parent 203742f commit 0f5b9a8

File tree

7 files changed

+231
-153
lines changed

7 files changed

+231
-153
lines changed

mlx/backend/cuda/rope.cu

Lines changed: 57 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -103,15 +103,21 @@ template <typename T, bool traditional, bool forward, int N = 4>
103103
__device__ void rope_impl(
104104
const T* in,
105105
T* out,
106-
int offset,
106+
const int* offset,
107107
float inv_freq,
108108
float scale,
109109
const cuda::std::array<int64_t, 3> strides,
110110
const cuda::std::array<int64_t, 3> out_strides,
111-
int64_t n_batch,
111+
int64_t offset_stride,
112+
int n_head,
112113
uint3 pos,
113114
uint3 dims) {
114-
float L = scale * static_cast<float>(pos.y + offset);
115+
auto n_head_up = N * ((n_head + N - 1) / N);
116+
auto head_idx = static_cast<int>((pos.z * N) % n_head_up);
117+
auto batch_idx = (pos.z * N) / n_head_up;
118+
auto batch_offset = offset[batch_idx * offset_stride];
119+
float L = scale * static_cast<float>(pos.y + batch_offset);
120+
auto mat_idx = batch_idx * n_head + head_idx;
115121

116122
// Compute costheta, sintheta
117123
float theta = L * inv_freq;
@@ -123,20 +129,19 @@ __device__ void rope_impl(
123129
size_t out_index_1, out_index_2;
124130
if (traditional) {
125131
out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] +
126-
N * pos.z * out_strides[0];
132+
mat_idx * out_strides[0];
127133
out_index_2 = out_index_1 + 1;
128134
in_index_1 =
129-
2 * pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
135+
2 * pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0];
130136
in_index_2 = in_index_1 + strides[2];
131137
} else {
132138
out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] +
133-
N * pos.z * out_strides[0];
139+
mat_idx * out_strides[0];
134140
out_index_2 = out_index_1 + dims.x * out_strides[2];
135-
in_index_1 =
136-
pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
141+
in_index_1 = pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0];
137142
in_index_2 = in_index_1 + dims.x * strides[2];
138143
}
139-
for (int i = 0; i < N && pos.z * N + i < n_batch; ++i) {
144+
for (int i = 0; i < N && head_idx + i < n_head; ++i) {
140145
// Read and write the output
141146
float x1 = static_cast<float>(in[in_index_1]);
142147
float x2 = static_cast<float>(in[in_index_2]);
@@ -167,7 +172,8 @@ __global__ void rope(
167172
float base,
168173
const __grid_constant__ cuda::std::array<int64_t, 3> strides,
169174
const __grid_constant__ cuda::std::array<int64_t, 3> out_strides,
170-
int64_t n_batch,
175+
int64_t offset_stride,
176+
int n_head,
171177
uint3 dims) {
172178
uint3 pos = make_uint3(
173179
blockIdx.x * blockDim.x + threadIdx.x,
@@ -182,12 +188,13 @@ __global__ void rope(
182188
rope_impl<T, traditional, forward>(
183189
in,
184190
out,
185-
*offset,
191+
offset,
186192
inv_freq,
187193
scale,
188194
strides,
189195
out_strides,
190-
n_batch,
196+
offset_stride,
197+
n_head,
191198
pos,
192199
dims);
193200
}
@@ -202,7 +209,8 @@ __global__ void rope_freqs(
202209
float base,
203210
const __grid_constant__ cuda::std::array<int64_t, 3> strides,
204211
const __grid_constant__ cuda::std::array<int64_t, 3> out_strides,
205-
int64_t n_batch,
212+
int64_t offset_stride,
213+
int n_head,
206214
uint3 dims,
207215
int64_t freq_stride) {
208216
uint3 pos = make_uint3(
@@ -217,12 +225,13 @@ __global__ void rope_freqs(
217225
rope_impl<T, traditional, forward>(
218226
in,
219227
out,
220-
*offset,
228+
offset,
221229
inv_freq,
222230
scale,
223231
strides,
224232
out_strides,
225-
n_batch,
233+
offset_stride,
234+
n_head,
226235
pos,
227236
dims);
228237
}
@@ -245,23 +254,28 @@ void RoPE::eval_gpu(
245254
auto& offset = inputs[1];
246255
auto& out = outputs[0];
247256

248-
if (in.ndim() < 3) {
249-
throw std::runtime_error("[RoPE] Input must have at least 3 dimensions");
250-
}
251-
252257
cuda::std::array<int64_t, 3> strides;
253258
cuda::std::array<int64_t, 3> out_strides;
254259
bool donated = false;
255260
int ndim = in.ndim();
256-
int dispatch_ndim = in.ndim();
261+
262+
int B = in.shape(0);
263+
int T = in.shape(-2);
264+
int D = in.shape(-1);
265+
size_t mat_size = T * D;
266+
int dispatch_ndim = ndim;
257267
while (in.shape(-dispatch_ndim) == 1 && dispatch_ndim > 3) {
258268
dispatch_ndim--;
259269
}
260-
size_t mat_size = in.shape(-2) * in.shape(-1);
270+
271+
int N = 1;
272+
for (int i = 1; i < (ndim - 2); ++i) {
273+
N *= in.shape(i);
274+
}
261275

262276
// We apply rope to less than the whole vector so copy to output and then
263277
// apply in-place.
264-
if (dims_ < in.shape(-1)) {
278+
if (dims_ < D) {
265279
donated = true;
266280
auto ctype =
267281
(in.flags().row_contiguous) ? CopyType::Vector : CopyType::General;
@@ -302,7 +316,7 @@ void RoPE::eval_gpu(
302316
out_strides[2] = out.strides()[ndim - 1];
303317

304318
// Some flags to help us dispatch below
305-
bool single = in.flags().row_contiguous && (mat_size == in.shape(-1));
319+
bool single = in.flags().row_contiguous && B == 1 && T == 1;
306320
bool with_freqs = inputs.size() == 3;
307321

308322
auto& encoder = cu::get_command_encoder(s);
@@ -319,7 +333,7 @@ void RoPE::eval_gpu(
319333
if (single && !with_freqs) {
320334
auto kernel =
321335
cu::rope_single<DataType, traditional.value, forward.value>;
322-
uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
336+
uint2 dims = make_uint2(dims_ / 2, N);
323337
auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
324338
encoder.add_kernel_node(
325339
kernel,
@@ -336,7 +350,7 @@ void RoPE::eval_gpu(
336350
} else if (single) {
337351
auto kernel =
338352
cu::rope_single_freqs<DataType, traditional.value, forward.value>;
339-
uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
353+
uint2 dims = make_uint2(dims_ / 2, N);
340354
auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
341355
encoder.add_kernel_node(
342356
kernel,
@@ -354,10 +368,14 @@ void RoPE::eval_gpu(
354368
} else if (with_freqs) {
355369
auto kernel =
356370
cu::rope_freqs<DataType, traditional.value, forward.value>;
357-
uint3 dims =
358-
make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
359-
dims.z = (dims.z + 3) / 4;
371+
int n_per_thread = 4;
372+
uint32_t dimz = B * ((N + n_per_thread - 1) / n_per_thread);
373+
uint3 dims = make_uint3(dims_ / 2, T, dimz);
360374
auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);
375+
int64_t offset_stride = 0;
376+
if (inputs[1].ndim() > 0) {
377+
offset_stride = inputs[1].strides()[0];
378+
}
361379
encoder.add_kernel_node(
362380
kernel,
363381
grid,
@@ -371,15 +389,20 @@ void RoPE::eval_gpu(
371389
std::log2(base_),
372390
strides,
373391
out_strides,
374-
in.size() / mat_size,
392+
offset_stride,
393+
N,
375394
dims,
376395
inputs[2].strides(0));
377396
} else {
378397
auto kernel = cu::rope<DataType, traditional.value, forward.value>;
379-
uint3 dims =
380-
make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
381-
dims.z = (dims.z + 3) / 4;
398+
int n_per_thread = 4;
399+
uint32_t dimz = B * ((N + n_per_thread - 1) / n_per_thread);
400+
uint3 dims = make_uint3(dims_ / 2, T, dimz);
382401
auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);
402+
int64_t offset_stride = 0;
403+
if (inputs[1].ndim() > 0) {
404+
offset_stride = inputs[1].strides()[0];
405+
}
383406
encoder.add_kernel_node(
384407
kernel,
385408
grid,
@@ -392,7 +415,8 @@ void RoPE::eval_gpu(
392415
std::log2(base_),
393416
strides,
394417
out_strides,
395-
in.size() / mat_size,
418+
offset_stride,
419+
N,
396420
dims);
397421
}
398422
});

0 commit comments

Comments
 (0)