#include <optional>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <algorithm>

#include "attention_dtypes.h"
#include "attention_utils.cuh"

namespace vllm {

// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
// Can be used to combine partial attention results (in the split-KV case).
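//
// Given two partial attention outputs O_p and O_s computed over disjoint KV
// ranges, with log-sum-exp values lse_p and lse_s, the merged result is the
// softmax-weighted combination
//   O   = (exp(lse_p) * O_p + exp(lse_s) * O_s) / (exp(lse_p) + exp(lse_s))
//   lse = log(exp(lse_p) + exp(lse_s))
// The kernel below evaluates this with the usual max-subtraction trick for
// numerical stability.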
template <typename scalar_t, const uint NUM_THREADS>
__global__ void merge_attn_states_kernel(
    scalar_t* output, float* output_lse, const scalar_t* prefix_output,
    const float* prefix_lse, const scalar_t* suffix_output,
    const float* suffix_lse, const uint num_tokens, const uint num_heads,
    const uint head_size) {
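  // Thread mapping: each thread handles one 128-bit pack of one head
  // (pack_size = 4 floats or 8 half/bf16 values), so the grid covers
  // num_tokens * num_heads * threads_per_head threads in total.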
  using pack_128b_t = uint4;
  const uint pack_size = 16 / sizeof(scalar_t);
  const uint threads_per_head = head_size / pack_size;

  const uint global_idx = blockIdx.x * NUM_THREADS + threadIdx.x;
  const uint token_head_threads = num_tokens * num_heads * threads_per_head;

  if (global_idx >= token_head_threads) return;

  // global_idx -> token_idx + head_idx + pack_idx
  const uint token_head_idx = global_idx / threads_per_head;
  const uint pack_idx = global_idx % threads_per_head;

  const uint token_idx = token_head_idx / num_heads;
  const uint head_idx = token_head_idx % num_heads;

  const uint pack_offset = pack_idx * pack_size;  // (0~15)*8, etc.
  const uint head_offset =
      token_idx * num_heads * head_size + head_idx * head_size;
  const scalar_t* prefix_head_ptr = prefix_output + head_offset;
  const scalar_t* suffix_head_ptr = suffix_output + head_offset;
  scalar_t* output_head_ptr = output + head_offset;

  float p_lse = prefix_lse[head_idx * num_tokens + token_idx];
  float s_lse = suffix_lse[head_idx * num_tokens + token_idx];
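  // Map an infinite LSE to -inf so that the corresponding partial result
  // receives zero weight below, instead of turning the max-subtraction and
  // exp() steps into NaN.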
  p_lse = std::isinf(p_lse) ? -std::numeric_limits<float>::infinity() : p_lse;
  s_lse = std::isinf(s_lse) ? -std::numeric_limits<float>::infinity() : s_lse;

  const float max_lse = fmaxf(p_lse, s_lse);
  p_lse = p_lse - max_lse;
  s_lse = s_lse - max_lse;
  const float p_se = expf(p_lse);
  const float s_se = expf(s_lse);
  const float out_se = p_se + s_se;
  const float p_scale = p_se / out_se;
  const float s_scale = s_se / out_se;
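  // p_scale and s_scale are the renormalized weights of the two partial
  // results (p_scale + s_scale == 1).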
| 56 | + |
| 57 | + if (pack_offset < head_size) { |
| 58 | + // Pack 128b load |
| 59 | + pack_128b_t p_out_pack = reinterpret_cast<const pack_128b_t*>( |
| 60 | + prefix_head_ptr)[pack_offset / pack_size]; |
| 61 | + pack_128b_t s_out_pack = reinterpret_cast<const pack_128b_t*>( |
| 62 | + suffix_head_ptr)[pack_offset / pack_size]; |
| 63 | + pack_128b_t o_out_pack; |
| 64 | + |
| 65 | +#pragma unroll |
| 66 | + for (uint i = 0; i < pack_size; ++i) { |
| 67 | + // Always use float for FMA to keep high precision. |
| 68 | + // half(uint16_t), bfloat16, float -> float. |
| 69 | + const float p_out_f = |
| 70 | + vllm::to_float(reinterpret_cast<const scalar_t*>(&p_out_pack)[i]); |
| 71 | + const float s_out_f = |
| 72 | + vllm::to_float(reinterpret_cast<const scalar_t*>(&s_out_pack)[i]); |
| 73 | + // fma: a * b + c = p_out_f * p_scale + (s_out_f * s_scale) |
| 74 | + const float o_out_f = p_out_f * p_scale + (s_out_f * s_scale); |
| 75 | + // float -> half(uint16_t), bfloat16, float. |
| 76 | + vllm::from_float(reinterpret_cast<scalar_t*>(&o_out_pack)[i], o_out_f); |
| 77 | + } |
| 78 | + |
    // Pack 128b store
    reinterpret_cast<pack_128b_t*>(output_head_ptr)[pack_offset / pack_size] =
        o_out_pack;
  }
  // Only the first thread of each (token, head) pair writes the merged LSE.
  if (output_lse != nullptr && pack_idx == 0) {
    float out_lse = logf(out_se) + max_lse;
    output_lse[head_idx * num_tokens + token_idx] = out_lse;
  }
}

}  // namespace vllm

// The following macro dispatches on the scalar data type of the output
// tensor. `fn` is a macro that invokes a function templated on scalar_t.
#define DISPATCH_BY_SCALAR_DTYPE(scalar_dtype, fn)                      \
  {                                                                     \
    if (scalar_dtype == at::ScalarType::Float) {                        \
      fn(float);                                                        \
    } else if (scalar_dtype == at::ScalarType::Half) {                  \
      fn(uint16_t);                                                     \
    } else if (scalar_dtype == at::ScalarType::BFloat16) {              \
      fn(__nv_bfloat16);                                                \
    } else {                                                            \
      TORCH_CHECK(false, "Unsupported data type of O: ", scalar_dtype); \
    }                                                                   \
  }
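// For example, the entry point at the bottom of this file calls
// DISPATCH_BY_SCALAR_DTYPE(output.dtype(), CALL_MERGE_ATTN_STATES_LAUNCHER),
// which expands into an if/else chain that invokes
// CALL_MERGE_ATTN_STATES_LAUNCHER(float) / (uint16_t) / (__nv_bfloat16).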

#define LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS)                      \
  {                                                                          \
    vllm::merge_attn_states_kernel<scalar_t, NUM_THREADS>                    \
        <<<grid, block, 0, stream>>>(                                        \
            reinterpret_cast<scalar_t*>(output.data_ptr()), output_lse_ptr,  \
            reinterpret_cast<scalar_t*>(prefix_output.data_ptr()),           \
            reinterpret_cast<float*>(prefix_lse.data_ptr()),                 \
            reinterpret_cast<scalar_t*>(suffix_output.data_ptr()),           \
            reinterpret_cast<float*>(suffix_lse.data_ptr()), num_tokens,     \
            num_heads, head_size);                                           \
  }

/* @brief Merges the attention states from prefix and suffix
 * into the output tensor. NUM_TOKENS: n, NUM_HEADS: h, HEAD_SIZE: d
 *
 * @param output [n,h,d] The output tensor to store the merged attention
 * states.
 * @param output_lse [h,n] Optional tensor to store the merged log-sum-exp
 * values.
 * @param prefix_output [n,h,d] The prefix attention states.
 * @param prefix_lse [h,n] The log-sum-exp values for the prefix attention
 * states.
 * @param suffix_output [n,h,d] The suffix attention states.
 * @param suffix_lse [h,n] The log-sum-exp values for the suffix attention
 * states.
 */
template <typename scalar_t>
void merge_attn_states_launcher(torch::Tensor& output,
                                std::optional<torch::Tensor> output_lse,
                                const torch::Tensor& prefix_output,
                                const torch::Tensor& prefix_lse,
                                const torch::Tensor& suffix_output,
                                const torch::Tensor& suffix_lse) {
  constexpr uint NUM_THREADS = 128;
  const uint num_tokens = output.size(0);
  const uint num_heads = output.size(1);
  const uint head_size = output.size(2);
  const uint pack_size = 16 / sizeof(scalar_t);
  TORCH_CHECK(head_size % pack_size == 0,
              "head_size must be a multiple of pack_size: ", pack_size);
  float* output_lse_ptr = nullptr;
  if (output_lse.has_value()) {
    output_lse_ptr = output_lse.value().data_ptr<float>();
  }
  // Each thread processes one pack of elements: float -> 4, half/bf16 -> 8.
  const uint threads_per_head = head_size / pack_size;
  const uint total_threads = num_tokens * num_heads * threads_per_head;

  dim3 block(NUM_THREADS);
  dim3 grid((total_threads + NUM_THREADS - 1) / NUM_THREADS);

  // Switch to the device of the output tensor and launch on its current
  // stream.
  const c10::cuda::OptionalCUDAGuard device_guard(device_of(output));
  auto stream = at::cuda::getCurrentCUDAStream();
  LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS);
}

#define CALL_MERGE_ATTN_STATES_LAUNCHER(scalar_t)                            \
  {                                                                          \
    merge_attn_states_launcher<scalar_t>(output, output_lse, prefix_output,  \
                                         prefix_lse, suffix_output,          \
                                         suffix_lse);                        \
  }

void merge_attn_states(torch::Tensor& output,
                       std::optional<torch::Tensor> output_lse,
                       const torch::Tensor& prefix_output,
                       const torch::Tensor& prefix_lse,
                       const torch::Tensor& suffix_output,
                       const torch::Tensor& suffix_lse) {
  DISPATCH_BY_SCALAR_DTYPE(output.dtype(), CALL_MERGE_ATTN_STATES_LAUNCHER);
}
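
// Illustrative usage sketch (tensor names here are hypothetical, not part of
// this file): outputs are [num_tokens, num_heads, head_size] on CUDA, LSE
// tensors are [num_heads, num_tokens] in float32 on the same device, e.g.
//
//   torch::Tensor out = torch::empty_like(prefix_out);
//   torch::Tensor out_lse = torch::empty_like(prefix_lse);
//   merge_attn_states(out, out_lse, prefix_out, prefix_lse,
//                     suffix_out, suffix_lse);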