
Commit a66eeab

bobboli and Tom-Zheng authored
[TRTLLM-9805][feat] Skip Softmax Attention. (NVIDIA#9821)
Signed-off-by: Bo Li <[email protected]>
Signed-off-by: Tian Zheng <[email protected]>
Co-authored-by: Tian Zheng <[email protected]>
1 parent dcd3f7b commit a66eeab

2,967 files changed (+10,152, -5,178 lines)


cpp/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
@@ -68,6 +68,7 @@ option(USING_OSS_CUTLASS_MOE_GEMM "Using open sourced Cutlass moe gemm kernel"
        ON)
 option(USING_OSS_CUTLASS_ALLREDUCE_GEMM
        "Using open sourced Cutlass AR gemm kernel" ON)
+option(SKIP_SOFTMAX_STAT "Enable Statistics of Skip-Softmax" OFF)

 message(STATUS "ENABLE_NVSHMEM is ${ENABLE_NVSHMEM}")

@@ -360,6 +361,11 @@ else()
                       $<$<COMPILE_LANGUAGE:CUDA>:ENABLE_NVSHMEM=0>)
 endif()

+if(SKIP_SOFTMAX_STAT)
+  add_compile_definitions("SKIP_SOFTMAX_STAT")
+  message(STATUS "SKIP_SOFTMAX_STAT is enabled")
+endif()
+
 # Fix linking issue with TRT 10, the detailed description about `--mcmodel` can
 # be found in
 # https://gcc.gnu.org/onlinedocs/gcc/x86-Options.html#index-mcmodel_003dmedium-1

cpp/kernels/fmha_v2/Makefile

Lines changed: 5 additions & 0 deletions
@@ -69,6 +69,11 @@ PREPROCESSOR_FLAGS += -DUSE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE
 # Do we want to use half accumulation for flash attention
 PREPROCESSOR_FLAGS += -DHALF_ACCUMULATION_FOR_FLASH_ATTENTION

+# Print the resulting sparsity for a given threshold in Skip-Softmax attention.
+# Note: inside TRT-LLM you only need "python scripts/build_wheel.py -D SKIP_SOFTMAX_STAT=ON ..." to use it.
+# Turn this on manually only if you want to build and run the unit test (bin/fmha.exe) with SKIP_SOFTMAX_STAT.
+# PREPROCESSOR_FLAGS += -DSKIP_SOFTMAX_STAT
+
 # Add FLAGS when generating cubins.
 ifdef GENERATE_CUBIN
 PREPROCESSOR_FLAGS += -DGENERATE_CUBIN
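The SKIP_SOFTMAX_STAT path only accumulates two counters per launch, the total number of softmax tiles and the number of skipped tiles (the compute.h hunk below atomicAdds them into params.skip_softmax_total_blocks and params.skip_softmax_skipped_blocks). A minimal host-side sketch of turning those counters into the reported sparsity follows; the function name, the device-pointer arguments, and the copy-back plumbing are illustrative assumptions, not the actual TRT-LLM reporting code.

// Illustrative sketch only: reads back two device counters that the kernel
// filled via atomicAdd(), then prints the skipped-tile ratio ("sparsity").
// The pointers are hypothetical stand-ins for the buffers behind
// params.skip_softmax_total_blocks / params.skip_softmax_skipped_blocks.
#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

void report_skip_softmax_sparsity(uint32_t const* d_total_blocks, uint32_t const* d_skipped_blocks)
{
    uint32_t total = 0, skipped = 0;
    cudaMemcpy(&total, d_total_blocks, sizeof(uint32_t), cudaMemcpyDeviceToHost);
    cudaMemcpy(&skipped, d_skipped_blocks, sizeof(uint32_t), cudaMemcpyDeviceToHost);
    if (total > 0)
    {
        printf("skip-softmax sparsity: %.2f%% (%u / %u tiles skipped)\n", 100.0 * skipped / total, skipped, total);
    }
}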

cpp/kernels/fmha_v2/setup.py

Lines changed: 85 additions & 33 deletions
Large diffs are not rendered by default.

cpp/kernels/fmha_v2/src/fmha/warpspec/compute.h

Lines changed: 43 additions & 4 deletions
@@ -256,7 +256,8 @@ struct Compute
             actual_kv_seqlen, alibi_head_scale, \
             USE_CUSTOM_MASK ? (head_info.mask_sum_s + q_step_idx * STEP_Q + local_q_tile_offset) \
                             : (q_step_idx * STEP_Q + head_info.q_tile_offset), \
-            kv_step_idx * STEP_KV, sage_scale_row, cbr, cbr_v, mutex_accessor, kv_step_idx == kv_idx_end - 1);
+            kv_step_idx * STEP_KV, sage_scale_row, cbr, cbr_v, mutex_accessor, \
+            &shared->skip_softmax_votes[kv_step_idx & 1][warpgroup_id], kv_step_idx == kv_idx_end - 1);

     ////////////////////////////////////////////////////////////////////////////////////////////////

@@ -360,6 +361,12 @@ struct Compute
         // Contiguous QKV FMHA assumes q, and kv have the same sequence length.
         int const actual_kv_seqlen = SEPARATE_Q_KV_BUFFER ? head_info.actual_kv_seqlen : actual_q_seqlen;

+        // Update the threshold of Skip-Softmax.
+        if constexpr (Kernel_traits::ENABLE_SKIP_SOFTMAX)
+        {
+            softmax.skip_softmax_threshold = params.skip_softmax_threshold_scale_factor / actual_kv_seqlen;
+        }
+
         // Calculate the alibi head_scaling_factor.
         float alibi_head_scale
             = APPLY_ALIBI ? get_alibi_head_scaling_factor<AlibiParams>(head_info.bidh, params.alibi_params) : 0.f;
@@ -513,6 +520,13 @@ struct Compute
                 }
             }
         }
+#ifdef SKIP_SOFTMAX_STAT
+        if (tidx == 0)
+        {
+            atomicAdd(params.skip_softmax_total_blocks, softmax.total_blocks);
+            atomicAdd(params.skip_softmax_skipped_blocks, softmax.skipped_blocks);
+        }
+#endif
     }

     ////////////////////////////////////////////////////////////////////////////////////////////////
@@ -522,8 +536,15 @@ struct Compute
         Compute_tile_o& ctile_o, float (&p_max)[Mma_tile_p::CORES_M], float (&p_sum)[Mma_tile_p::CORES_M],
         int const tidx, int const actual_kv_seqlen, float const alibi_head_scale, int const row_offset,
         int const col_offset, int const sage_scale_row, Circular_buffer_q_reader& cbr, Circular_buffer_kv_reader& cbr_v,
-        OrderedMutexAccessor& mutex, bool complete = false)
+        OrderedMutexAccessor& mutex, uint32_t* skip_softmax_vote, bool complete = false)
     {
+
+        // Skip-softmax vote initialization.
+        if (tidx == 0)
+        {
+            // Note that we need a named_barrier_wait in compute_single_tile to make sure the init happens before voting.
+            *skip_softmax_vote = 1;
+        }
 // load the scales of K/V from global memory
 #define LOAD_SCALES_KV(dst, which, blocks_per_step, block_size) \
     if constexpr (block_size > 0) \
@@ -557,6 +578,10 @@ struct Compute
         // Ctile_p is only used once by each n step.
         ctile_p.clear();

+        // If skip_softmax is enabled, make sure there is no racing between the initialization and writing of
+        // skip_softmax_vote.
+        named_barrier_wait(Kernel_traits::SKIP_SOFTMAX_BARRIER_ID + threadIdx.x / 128, 128);
+
         // BMM1 (Q x K').
         warpgroup_arrive();

@@ -626,8 +651,22 @@ struct Compute
         softmax.apply_alibi_and_mask<APPLY_MASK>(
             ctile_p, params.alibi_params, alibi_head_scale, actual_kv_seqlen, row_offset, col_offset);

-        // Softmax Exp, max/sum, and update scales.
-        softmax.compute_and_update_scale<IS_FIRST_COL>(p_max, p_sum);
+        // Softmax Exp, max/sum, and update scales. If it returns false, we skip the rest.
+        if (!softmax.compute_and_update_scale<IS_FIRST_COL>(p_max, p_sum, skip_softmax_vote))
+        {
+            if constexpr (ENABLE_MUTEX && Kernel_traits::ELEMENT_BYTES == 1)
+            {
+                // Notify the other warpgroup to execute QGMMA.
+                mutex.named_bar_arrive();
+            }
+            // Need to wait for V, otherwise compute-sanitizer synccheck will fail.
+            int ready2 = cbr_v.peek();
+            if (!ready2)
+            {
+                cbr_v.wait();
+            }
+            return;
+        }

         // experiments show that here is the best place to load scales of V
         float scales_v[SAGE_BLOCKS_PER_STEP_V];
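The hunks above wire the skip decision through compute_single_tile: thread 0 resets the shared vote word to 1, a named barrier orders that reset against the voting, and a false return from compute_and_update_scale lets the warpgroup drop the softmax and BMM2 work for the tile (while still releasing the mutex and waiting for the V buffer so the pipeline stays consistent). The host-side sketch below models only the unanimous-vote part of that flow, assuming a 4-warp warpgroup; the kernel uses atomicAnd on a shared-memory word plus named_barrier_wait rather than a loop.

// Illustrative host-side model of the per-tile skip vote.
#include <cstdint>
#include <cstdio>
#include <vector>

// One entry per warp in the warpgroup: true if every row handled by that warp
// satisfies exp(local_max - global_max) < skip_softmax_threshold.
bool warpgroup_skips_tile(std::vector<bool> const& per_warp_skip)
{
    uint32_t vote = 1; // models "*skip_softmax_vote = 1;" done by thread 0 before voting
    for (bool warp_skip : per_warp_skip)
    {
        vote &= uint32_t(warp_skip); // models atomicAnd(skip_softmax_vote, uint32_t(skip));
    }
    // After the named barrier, every thread reads the same value back.
    return vote != 0;
}

int main()
{
    // The tile is skipped only if all warps agree; one dissenting warp keeps it.
    printf("%d\n", warpgroup_skips_tile({true, true, true, true}));  // 1: skip softmax + BMM2
    printf("%d\n", warpgroup_skips_tile({true, false, true, true})); // 0: fall through
    return 0;
}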

cpp/kernels/fmha_v2/src/fmha/warpspec/epilogue.h

Lines changed: 138 additions & 20 deletions
@@ -17,6 +17,8 @@

 #pragma once

+#include "fmha/hopper/arrive_wait.h"
+
 #include <fmha/softmax.h>
 #include <fmha/traits.h>
 #include <fmha/utils.h>
@@ -104,6 +106,12 @@ struct Softmax_base
         CHECK_IF_NEG_INF_EXISTS = SLIDING_OR_CHUNKED_ATTENTION || USE_CUSTOM_MASK
     };

+    // There are 2 warpgroups, so barriers 0x3 and 0x4 are used.
+    enum
+    {
+        SKIP_SOFTMAX_BARRIER = Kernel_traits::SKIP_SOFTMAX_BARRIER_ID
+    };
+
     // Ctor.
     template <typename Params>
     inline __device__ Softmax_base(Params params, int tidx)
@@ -114,6 +122,11 @@ struct Softmax_base
         , log2_chunked_attention_size_(params.log2_chunked_attention_size)
         , packed_mask_ptr_{reinterpret_cast<uint32_t*>(params.packed_mask_ptr)}
         , params_packed_mask_stride_in_bytes_{params.packed_mask_stride_in_bytes}
+#ifdef SKIP_SOFTMAX_STAT
+        , total_blocks(0)
+        , skipped_blocks(0)
+#endif
+        , skip_softmax_threshold(0)
     {

         int warp = tidx / 32;
@@ -330,31 +343,79 @@ struct Softmax_base
     }

     // Calculate max/sum, and update flash-attention scales.
+    // Returns false if the tile is skipped by the skip-softmax attention feature.
     template <bool IS_FIRST_COL>
-    inline __device__ void compute_and_update_scale(
-        float (&global_max)[Mma_tile_p::CORES_M], float (&global_sum)[Mma_tile_p::CORES_M])
+    inline __device__ bool compute_and_update_scale(
+        float (&global_max)[Mma_tile_p::CORES_M], float (&global_sum)[Mma_tile_p::CORES_M], uint32_t* skip_softmax_vote)
     {
         float const scale = reinterpret_cast<float const&>(scale_bmm1_);

+        // Whether this warpgroup skips the softmax.
+        constexpr bool may_skip = Kernel_traits::ENABLE_SKIP_SOFTMAX && !IS_FIRST_COL;
+        bool skip = may_skip;
+
         // Row-wise max of current tile.
 #pragma unroll
         for (int mi = 0; mi < Mma_tile_p::CORES_M; mi++)
         {
-            if (IS_FIRST_COL)
-            {
-                local_max_[mi] = elt_[mi][0];
-            }
-            else
-            {
-                local_max_[mi] = fmaxf(global_max[mi], elt_[mi][0]);
-            }
+            local_max_[mi] = elt_[mi][0];
 #pragma unroll
             for (int ni = 1; ni < Mma_tile_p::CORES_N * 2; ni++)
             {
                 local_max_[mi] = fmaxf(local_max_[mi], elt_[mi][ni]);
             }
             local_max_[mi] = fmaxf(__shfl_xor_sync(uint32_t(-1), local_max_[mi], 1), local_max_[mi]);
             local_max_[mi] = fmaxf(__shfl_xor_sync(uint32_t(-1), local_max_[mi], 2), local_max_[mi]);
+
+            if constexpr (may_skip)
+            {
+                // AND(&) the CORES_M results, then `skip` means whether to skip
+                // the CORES_M(=2) rows.
+                if constexpr (!EXP2F_OPTIMIZATION)
+                {
+                    skip &= expf(local_max_[mi] - global_max[mi]) < skip_softmax_threshold;
+                }
+                else
+                {
+                    skip &= exp2f((local_max_[mi] - global_max[mi]) * scale) < skip_softmax_threshold;
+                }
+            }
+
+            if (!IS_FIRST_COL)
+            {
+                local_max_[mi] = fmaxf(local_max_[mi], global_max[mi]);
+            }
+        }
+
+        if constexpr (Kernel_traits::ENABLE_SKIP_SOFTMAX)
+        {
+#ifdef SKIP_SOFTMAX_STAT
+            total_blocks++;
+#endif
+            if constexpr (may_skip)
+            {
+                // AND(&) the results together in a warp, then `skip` means whether to skip
+                // all the 16 rows managed by this warp.
+                // Each group of 4 threads (e.g. T0~T3) has the same `skip`, so 0x11111111 would be enough
+                // instead of 0xffffffff, but the perf is the same.
+                skip = __all_sync(0xffffffff, skip);
+                if (threadIdx.x % 32 == 0)
+                {
+                    // The leader of each warp votes.
+                    atomicAnd(skip_softmax_vote, uint32_t(skip));
+                }
+                // WG0 uses the 0x3 barrier, WG1 uses the 0x4 barrier.
+                named_barrier_wait(SKIP_SOFTMAX_BARRIER + threadIdx.x / 128, 128);
+                skip = *((uint32_t volatile*) skip_softmax_vote);
+                if (skip)
+                {
+#ifdef SKIP_SOFTMAX_STAT
+                    skipped_blocks++;
+#endif
+                    return false;
+                }
+            }
         }

         // Softmax Exp.
@@ -436,6 +497,7 @@ struct Softmax_base
                 global_max[mi] = max_new;
             }
         }
+        return true;
     }

     // Update flash attention scales and pack elements for BMM2.
@@ -513,6 +575,13 @@ struct Softmax_base
     float correction_[Mma_tile_p::CORES_M];
     // The packed mask.
     uint4 packed_mask_;
+    // Skip softmax when exp(local_max - global_max) < skip_softmax_threshold.
+    float skip_softmax_threshold;
+#ifdef SKIP_SOFTMAX_STAT
+    // Statistics of skip-softmax.
+    uint32_t total_blocks;
+    uint32_t skipped_blocks;
+#endif
 };

 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -868,35 +937,83 @@ struct Softmax<Hopper_qgmma_e4m3_fp32_traits, Kernel_traits>
     }

     // Calculate max/sum, and update flash-attention scales.
+    // Returns false if the tile is skipped by the skip-softmax attention feature.
     template <bool IS_FIRST_COL>
-    inline __device__ void compute_and_update_scale(
-        float (&global_max)[Mma_tile_p::CORES_M], float (&global_sum)[Mma_tile_p::CORES_M])
+    inline __device__ bool compute_and_update_scale(
+        float (&global_max)[Mma_tile_p::CORES_M], float (&global_sum)[Mma_tile_p::CORES_M], uint32_t* skip_softmax_vote)
     {
         float const scale = reinterpret_cast<float const&>(this->scale_bmm1_);
         float(&local_max_)[Mma_tile_p::CORES_M] = this->local_max_;
         float(&local_sum_)[Mma_tile_p::CORES_M] = this->local_sum_;
         float(&correction_)[Mma_tile_p::CORES_M] = this->correction_;
         float(&elt_)[Mma_tile_p::CORES_M][Mma_tile_p::CORES_N * 2] = this->elt_;

+        // Whether this warpgroup skips the softmax.
+        constexpr bool may_skip = Kernel_traits::ENABLE_SKIP_SOFTMAX && !IS_FIRST_COL;
+        bool skip = may_skip;
+
         // Row-wise max of current tile.
 #pragma unroll
         for (int mi = 0; mi < Mma_tile_p::CORES_M; mi++)
         {
-            if (IS_FIRST_COL)
-            {
-                local_max_[mi] = elt_[mi][0];
-            }
-            else
-            {
-                local_max_[mi] = fmaxf(global_max[mi], elt_[mi][0]);
-            }
+            local_max_[mi] = elt_[mi][0];
 #pragma unroll
             for (int ni = 1; ni < Mma_tile_p::CORES_N * 2; ni++)
             {
                 local_max_[mi] = fmaxf(local_max_[mi], elt_[mi][ni]);
             }
             local_max_[mi] = fmaxf(__shfl_xor_sync(uint32_t(-1), local_max_[mi], 1), local_max_[mi]);
             local_max_[mi] = fmaxf(__shfl_xor_sync(uint32_t(-1), local_max_[mi], 2), local_max_[mi]);
+            if constexpr (may_skip)
+            {
+                // AND(&) the CORES_M results, then `skip` means whether to skip
+                // the CORES_M(=2) rows.
+                if constexpr (!EXP2F_OPTIMIZATION)
+                {
+                    skip &= expf(local_max_[mi] - global_max[mi]) < this->skip_softmax_threshold;
+                }
+                else
+                {
+                    skip &= exp2f((local_max_[mi] - global_max[mi]) * scale) < this->skip_softmax_threshold;
+                }
+            }
+            if (!IS_FIRST_COL)
+            {
+                local_max_[mi] = fmaxf(local_max_[mi], global_max[mi]);
+            }
+        }
+
+        if constexpr (Kernel_traits::ENABLE_SKIP_SOFTMAX)
+        {
+#ifdef SKIP_SOFTMAX_STAT
+            this->total_blocks++;
+#endif
+
+            if constexpr (may_skip)
+            {
+                // AND(&) the results together in a warp, then `skip` means whether to skip
+                // all the 16 rows managed by this warp.
+                // Each group of 4 threads (e.g. T0~T3) has the same `skip`, so 0x11111111 would be enough
+                // instead of 0xffffffff, but the perf is the same.
+                skip = __all_sync(0xffffffff, skip);
+                if (threadIdx.x % 32 == 0)
+                {
+                    // The leader of each warp votes.
+                    atomicAnd(skip_softmax_vote, uint32_t(skip));
+                }
+                // WG0 uses the 0x3 barrier, WG1 uses the 0x4 barrier.
+                named_barrier_wait(Base::SKIP_SOFTMAX_BARRIER + threadIdx.x / 128, 128);
+                skip = *((uint32_t volatile*) skip_softmax_vote);
+                if (skip)
+                {
+#ifdef SKIP_SOFTMAX_STAT
+                    this->skipped_blocks++;
+#endif
+                    return false;
+                }
+            }
         }

         // Softmax Exp.
@@ -987,6 +1104,7 @@ struct Softmax<Hopper_qgmma_e4m3_fp32_traits, Kernel_traits>
                 global_max[mi] = max_new;
             }
         }
+        return true;
     }

     // Update flash attention scales and pack elements for BMM2.
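The criterion that feeds the vote is purely row-local: a KV tile can be dropped for a row only when exp(local_max - global_max) is already below skip_softmax_threshold, i.e. when every score in the tile is so far under the running max that its exponentials cannot materially change the flash-attention running sum. A small numerical sketch of that check follows; the scale-factor value is an arbitrary example, whereas in the kernel it comes from params.skip_softmax_threshold_scale_factor and is divided by the actual KV sequence length, as in the compute.h hunk above.

// Minimal numerical sketch of the skip criterion used in compute_and_update_scale().
// The scale factor below is a made-up example value, not a recommended setting.
#include <cmath>
#include <cstdio>

// A KV tile is a skip candidate for one row when the largest new score is so far
// below the running max that its exponentials are negligible in the running sum.
bool row_may_skip(float local_max, float global_max, float threshold)
{
    return std::exp(local_max - global_max) < threshold;
}

int main()
{
    int const actual_kv_seqlen = 8192;
    float const scale_factor = 1e-3f; // hypothetical params.skip_softmax_threshold_scale_factor
    float const threshold = scale_factor / actual_kv_seqlen; // as set in compute.h

    printf("%d\n", row_may_skip(/*local_max=*/-10.0f, /*global_max=*/10.0f, threshold)); // 1: negligible tile
    printf("%d\n", row_may_skip(/*local_max=*/9.5f, /*global_max=*/10.0f, threshold));   // 0: must be processed
    return 0;
}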
