Commit e9528f6

[Kernel] support merge_attn_states CUDA kernel, 3x speedup (vllm-project#16173)
Signed-off-by: DefTruth <[email protected]>
1 parent 51baa9c commit e9528f6

File tree

10 files changed: +519 -4 lines changed


CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -230,6 +230,7 @@ set(VLLM_EXT_SRC
     "csrc/cache_kernels.cu"
     "csrc/attention/paged_attention_v1.cu"
     "csrc/attention/paged_attention_v2.cu"
+    "csrc/attention/merge_attn_states.cu"
     "csrc/pos_encoding_kernels.cu"
     "csrc/activation_kernels.cu"
     "csrc/layernorm_kernels.cu"
csrc/attention/merge_attn_states.cu (new file)

Lines changed: 173 additions & 0 deletions
#include <optional>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <algorithm>

#include "attention_dtypes.h"
#include "attention_utils.cuh"

namespace vllm {

// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
// can be used to combine partial attention results (in the split-KV case)
template <typename scalar_t, const uint NUM_THREADS>
__global__ void merge_attn_states_kernel(
    scalar_t* output, float* output_lse, const scalar_t* prefix_output,
    const float* prefix_lse, const scalar_t* suffix_output,
    const float* suffix_lse, const uint num_tokens, const uint num_heads,
    const uint head_size) {
  using pack_128b_t = uint4;
  const uint pack_size = 16 / sizeof(scalar_t);
  const uint threads_per_head = head_size / pack_size;

  const uint global_idx = blockIdx.x * NUM_THREADS + threadIdx.x;
  const uint token_head_threads = num_tokens * num_heads * threads_per_head;

  if (global_idx >= token_head_threads) return;

  // global_idx -> token_idx + head_idx + pack_idx
  const uint token_head_idx = global_idx / threads_per_head;
  const uint pack_idx = global_idx % threads_per_head;

  const uint token_idx = token_head_idx / num_heads;
  const uint head_idx = token_head_idx % num_heads;

  const uint pack_offset = pack_idx * pack_size;  // (0~15)*8, etc.
  const uint head_offset =
      token_idx * num_heads * head_size + head_idx * head_size;
  const scalar_t* prefix_head_ptr = prefix_output + head_offset;
  const scalar_t* suffix_head_ptr = suffix_output + head_offset;
  scalar_t* output_head_ptr = output + head_offset;

  float p_lse = prefix_lse[head_idx * num_tokens + token_idx];
  float s_lse = suffix_lse[head_idx * num_tokens + token_idx];
  // Replace an infinite LSE (e.g. from an empty split) with -inf so that
  // the corresponding partial result gets zero weight below.
  p_lse = std::isinf(p_lse) ? -std::numeric_limits<float>::infinity() : p_lse;
  s_lse = std::isinf(s_lse) ? -std::numeric_limits<float>::infinity() : s_lse;

  const float max_lse = fmaxf(p_lse, s_lse);
  p_lse = p_lse - max_lse;
  s_lse = s_lse - max_lse;
  const float p_se = expf(p_lse);
  const float s_se = expf(s_lse);
  const float out_se = p_se + s_se;
  const float p_scale = p_se / out_se;
  const float s_scale = s_se / out_se;

  if (pack_offset < head_size) {
    // Pack 128b load
    pack_128b_t p_out_pack = reinterpret_cast<const pack_128b_t*>(
        prefix_head_ptr)[pack_offset / pack_size];
    pack_128b_t s_out_pack = reinterpret_cast<const pack_128b_t*>(
        suffix_head_ptr)[pack_offset / pack_size];
    pack_128b_t o_out_pack;

#pragma unroll
    for (uint i = 0; i < pack_size; ++i) {
      // Always use float for FMA to keep high precision.
      // half(uint16_t), bfloat16, float -> float.
      const float p_out_f =
          vllm::to_float(reinterpret_cast<const scalar_t*>(&p_out_pack)[i]);
      const float s_out_f =
          vllm::to_float(reinterpret_cast<const scalar_t*>(&s_out_pack)[i]);
      // fma: a * b + c = p_out_f * p_scale + (s_out_f * s_scale)
      const float o_out_f = p_out_f * p_scale + (s_out_f * s_scale);
      // float -> half(uint16_t), bfloat16, float.
      vllm::from_float(reinterpret_cast<scalar_t*>(&o_out_pack)[i], o_out_f);
    }

    // Pack 128b storage
    reinterpret_cast<pack_128b_t*>(output_head_ptr)[pack_offset / pack_size] =
        o_out_pack;
  }
  // We only need to write to output_lse once per head.
  if (output_lse != nullptr && pack_idx == 0) {
    float out_lse = logf(out_se) + max_lse;
    output_lse[head_idx * num_tokens + token_idx] = out_lse;
  }
}

}  // namespace vllm

// The following macro is used to dispatch the conversion function based on
// the output data type. The FN is a macro that calls a function with
// template<typename scalar_t>.
#define DISPATCH_BY_SCALAR_DTYPE(scalar_dtype, fn)                      \
  {                                                                     \
    if (scalar_dtype == at::ScalarType::Float) {                        \
      fn(float);                                                        \
    } else if (scalar_dtype == at::ScalarType::Half) {                  \
      fn(uint16_t);                                                     \
    } else if (scalar_dtype == at::ScalarType::BFloat16) {              \
      fn(__nv_bfloat16);                                                \
    } else {                                                            \
      TORCH_CHECK(false, "Unsupported data type of O: ", scalar_dtype); \
    }                                                                   \
  }

#define LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS)                      \
  {                                                                          \
    vllm::merge_attn_states_kernel<scalar_t, NUM_THREADS><<<grid, block>>>(  \
        reinterpret_cast<scalar_t*>(output.data_ptr()), output_lse_ptr,      \
        reinterpret_cast<scalar_t*>(prefix_output.data_ptr()),               \
        reinterpret_cast<float*>(prefix_lse.data_ptr()),                     \
        reinterpret_cast<scalar_t*>(suffix_output.data_ptr()),               \
        reinterpret_cast<float*>(suffix_lse.data_ptr()), num_tokens,         \
        num_heads, head_size);                                               \
  }

/* @brief Merges the attention states from prefix and suffix
 * into the output tensor. NUM_TOKENS: n, NUM_HEADS: h, HEAD_SIZE: d
 *
 * @param output [n,h,d] The output tensor to store the merged attention states.
 * @param output_lse [h,n] Optional tensor to store the log-sum-exp values.
 * @param prefix_output [n,h,d] The prefix attention states.
 * @param prefix_lse [h,n] The log-sum-exp values for the prefix attention
 * states.
 * @param suffix_output [n,h,d] The suffix attention states.
 * @param suffix_lse [h,n] The log-sum-exp values for the suffix attention
 * states.
 */
template <typename scalar_t>
void merge_attn_states_launcher(torch::Tensor& output,
                                std::optional<torch::Tensor> output_lse,
                                const torch::Tensor& prefix_output,
                                const torch::Tensor& prefix_lse,
                                const torch::Tensor& suffix_output,
                                const torch::Tensor& suffix_lse) {
  constexpr uint NUM_THREADS = 128;
  const uint num_tokens = output.size(0);
  const uint num_heads = output.size(1);
  const uint head_size = output.size(2);
  const uint pack_size = 16 / sizeof(scalar_t);
  TORCH_CHECK(head_size % pack_size == 0,
              "head_size must be a multiple of pack_size:", pack_size);
  float* output_lse_ptr = nullptr;
  if (output_lse.has_value()) {
    output_lse_ptr = output_lse.value().data_ptr<float>();
  }
  // Process one pack of elements per thread. float -> 4, half/bf16 -> 8.
  const uint threads_per_head = head_size / pack_size;
  const uint total_threads = num_tokens * num_heads * threads_per_head;

  dim3 block(NUM_THREADS);
  dim3 grid((total_threads + NUM_THREADS - 1) / NUM_THREADS);

  LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS);
}

#define CALL_MERGE_ATTN_STATES_LAUNCHER(scalar_t)                            \
  {                                                                          \
    merge_attn_states_launcher<scalar_t>(output, output_lse, prefix_output,  \
                                         prefix_lse, suffix_output,          \
                                         suffix_lse);                        \
  }

void merge_attn_states(torch::Tensor& output,
                       std::optional<torch::Tensor> output_lse,
                       const torch::Tensor& prefix_output,
                       const torch::Tensor& prefix_lse,
                       const torch::Tensor& suffix_output,
                       const torch::Tensor& suffix_lse) {
  DISPATCH_BY_SCALAR_DTYPE(output.dtype(), CALL_MERGE_ATTN_STATES_LAUNCHER);
}
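For reference, the merge the kernel performs per (token, head) pair is a standard log-sum-exp combination: each partial output is rescaled by a weight derived from its LSE, and the merged LSE is log(exp(p_lse - max) + exp(s_lse - max)) + max. Below is a minimal host-side C++ sketch of that math, not part of the commit; the helper name merge_one_head and the test values are illustrative. It can serve as a plain-CPU cross-check of the kernel's per-element formula.

#include <cmath>
#include <cstdio>
#include <vector>

// Host-side reference of the per-(token, head) merge the CUDA kernel performs.
// Inputs: two partial attention outputs for one head (length head_size) and
// their log-sum-exp values. Outputs: the merged vector and the merged LSE.
static float merge_one_head(const std::vector<float>& prefix_out, float p_lse,
                            const std::vector<float>& suffix_out, float s_lse,
                            std::vector<float>& merged_out) {
  // Mirror the kernel's guard: an infinite LSE is treated as -inf so that
  // partial result gets zero weight.
  if (std::isinf(p_lse)) p_lse = -INFINITY;
  if (std::isinf(s_lse)) s_lse = -INFINITY;

  const float max_lse = std::fmax(p_lse, s_lse);
  const float p_se = std::exp(p_lse - max_lse);
  const float s_se = std::exp(s_lse - max_lse);
  const float out_se = p_se + s_se;

  merged_out.resize(prefix_out.size());
  for (size_t i = 0; i < prefix_out.size(); ++i) {
    merged_out[i] =
        prefix_out[i] * (p_se / out_se) + suffix_out[i] * (s_se / out_se);
  }
  // Merged log-sum-exp, matching `logf(out_se) + max_lse` in the kernel.
  return std::log(out_se) + max_lse;
}

int main() {
  // Illustrative values only: head_size = 4.
  std::vector<float> prefix = {1.0f, 2.0f, 3.0f, 4.0f};
  std::vector<float> suffix = {4.0f, 3.0f, 2.0f, 1.0f};
  std::vector<float> merged;
  const float lse =
      merge_one_head(prefix, /*p_lse=*/0.5f, suffix, /*s_lse=*/1.5f, merged);
  std::printf("merged[0]=%f merged_lse=%f\n", merged[0], lse);
  return 0;
}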

csrc/ops.h

Lines changed: 9 additions & 0 deletions
@@ -52,6 +52,15 @@ void paged_attention_v2(
     const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
     const int64_t blocksparse_head_sliding_step);
 
+#ifndef USE_ROCM
+void merge_attn_states(torch::Tensor& output,
+                       std::optional<torch::Tensor> output_lse,
+                       const torch::Tensor& prefix_output,
+                       const torch::Tensor& prefix_lse,
+                       const torch::Tensor& suffix_output,
+                       const torch::Tensor& suffix_lse);
+#endif
+
 void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
               double epsilon);

csrc/torch_bindings.cpp

Lines changed: 15 additions & 0 deletions
@@ -64,6 +64,21 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       " int blocksparse_head_sliding_step) -> ()");
   ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2);
 
+#ifndef USE_ROCM
+  // Merge attn states
+  // Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
+  // can be used to combine partial attention results (in the split-KV case)
+  ops.def(
+      "merge_attn_states("
+      " Tensor! output,"
+      " Tensor!? output_lse,"
+      " Tensor prefix_output,"
+      " Tensor prefix_lse,"
+      " Tensor suffix_output,"
+      " Tensor suffix_lse) -> ()");
+  ops.impl("merge_attn_states", torch::kCUDA, &merge_attn_states);
+#endif
+
   // Activation ops
   // Activation function used in SwiGLU.
   ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
