1717
1818#pragma once
1919
20+ #include " fmha/hopper/arrive_wait.h"
21+
2022#include < fmha/softmax.h>
2123#include < fmha/traits.h>
2224#include < fmha/utils.h>
@@ -104,6 +106,12 @@ struct Softmax_base
104106 CHECK_IF_NEG_INF_EXISTS = SLIDING_OR_CHUNKED_ATTENTION || USE_CUSTOM_MASK
105107 };
106108
109+ // There are 2 warpgroups so 0x3 and 0x4 are used
110+ enum
111+ {
112+ SKIP_SOFTMAX_BARRIER = Kernel_traits::SKIP_SOFTMAX_BARRIER_ID
113+ };
114+
107115 // Ctor.
108116 template <typename Params>
109117 inline __device__ Softmax_base (Params params, int tidx)
@@ -114,6 +122,11 @@ struct Softmax_base
114122 , log2_chunked_attention_size_(params.log2_chunked_attention_size)
115123 , packed_mask_ptr_{reinterpret_cast <uint32_t *>(params.packed_mask_ptr )}
116124 , params_packed_mask_stride_in_bytes_{params.packed_mask_stride_in_bytes }
125+ #ifdef SKIP_SOFTMAX_STAT
126+ , total_blocks(0 )
127+ , skipped_blocks(0 )
128+ #endif
129+ , skip_softmax_threshold(0 )
117130 {
118131
119132 int warp = tidx / 32 ;
@@ -330,31 +343,79 @@ struct Softmax_base
330343 }
331344
332345 // Calculate max/sum, and update flash-attention scales.
346+ // Returns false if skipped due to skip-softmax attention feature.
333347 template <bool IS_FIRST_COL>
334- inline __device__ void compute_and_update_scale (
335- float (&global_max)[Mma_tile_p::CORES_M], float (&global_sum)[Mma_tile_p::CORES_M])
348+ inline __device__ bool compute_and_update_scale (
349+ float (&global_max)[Mma_tile_p::CORES_M], float (&global_sum)[Mma_tile_p::CORES_M], uint32_t* skip_softmax_vote )
336350 {
337351 float const scale = reinterpret_cast <float const &>(scale_bmm1_);
338352
353+ // whether this warpgroup skips the softmax
354+ constexpr bool may_skip = Kernel_traits::ENABLE_SKIP_SOFTMAX && !IS_FIRST_COL;
355+ bool skip = may_skip;
356+
339357// Row-wise max of current tile.
340358#pragma unroll
341359 for (int mi = 0 ; mi < Mma_tile_p::CORES_M; mi++)
342360 {
343- if (IS_FIRST_COL)
344- {
345- local_max_[mi] = elt_[mi][0 ];
346- }
347- else
348- {
349- local_max_[mi] = fmaxf (global_max[mi], elt_[mi][0 ]);
350- }
361+ local_max_[mi] = elt_[mi][0 ];
351362#pragma unroll
352363 for (int ni = 1 ; ni < Mma_tile_p::CORES_N * 2 ; ni++)
353364 {
354365 local_max_[mi] = fmaxf (local_max_[mi], elt_[mi][ni]);
355366 }
356367 local_max_[mi] = fmaxf (__shfl_xor_sync (uint32_t (-1 ), local_max_[mi], 1 ), local_max_[mi]);
357368 local_max_[mi] = fmaxf (__shfl_xor_sync (uint32_t (-1 ), local_max_[mi], 2 ), local_max_[mi]);
369+
370+ if constexpr (may_skip)
371+ {
372+ // AND(&) the CORES_M results, then `skip` means whether to skip
373+ // the CORES_M(=2) rows
374+ if constexpr (!EXP2F_OPTIMIZATION)
375+ {
376+ skip &= expf (local_max_[mi] - global_max[mi]) < skip_softmax_threshold;
377+ }
378+ else
379+ {
380+ skip &= exp2f ((local_max_[mi] - global_max[mi]) * scale) < skip_softmax_threshold;
381+ }
382+ }
383+
384+ if (!IS_FIRST_COL)
385+ {
386+ local_max_[mi] = fmaxf (local_max_[mi], global_max[mi]);
387+ }
388+ }
389+
390+ if constexpr (Kernel_traits::ENABLE_SKIP_SOFTMAX)
391+ {
392+ #ifdef SKIP_SOFTMAX_STAT
393+ total_blocks++;
394+ #endif
395+ if constexpr (may_skip)
396+ {
397+
398+ // AND(&) the results together in a warp, then `skip` means whether to skip
399+ // all the 16 rows managed by this warp.
400+ // each 4 threads (e.g. T0~T3) have the same `skip`, only 0x11111111 is needed
401+ // instead of 0xffffffff. But the perf is the same.
402+ skip = __all_sync (0xffffffff , skip);
403+ if (threadIdx.x % 32 == 0 )
404+ {
405+ // The leader of each warp votes.
406+ atomicAnd (skip_softmax_vote, uint32_t (skip));
407+ }
408+ // WG0 uses 0x3 barrier, WG1 uses 0x4 barrier
409+ named_barrier_wait (SKIP_SOFTMAX_BARRIER + threadIdx.x / 128 , 128 );
410+ skip = *((uint32_t volatile *) skip_softmax_vote);
411+ if (skip)
412+ {
413+ #ifdef SKIP_SOFTMAX_STAT
414+ skipped_blocks++;
415+ #endif
416+ return false ;
417+ }
418+ }
358419 }
359420
360421// Softmax Exp.
@@ -436,6 +497,7 @@ struct Softmax_base
436497 global_max[mi] = max_new;
437498 }
438499 }
500+ return true ;
439501 }
440502
441503 // Update flash attention scales and pack elements for BMM2.
@@ -513,6 +575,13 @@ struct Softmax_base
513575 float correction_[Mma_tile_p::CORES_M];
514576 // The packed mask.
515577 uint4 packed_mask_;
578+ // Skip softmax when exp(local_max - global_max) < skip_softmax_threshold.
579+ float skip_softmax_threshold;
580+ #ifdef SKIP_SOFTMAX_STAT
581+ // Statistics of skip-softmax
582+ uint32_t total_blocks;
583+ uint32_t skipped_blocks;
584+ #endif
516585};
517586
518587// //////////////////////////////////////////////////////////////////////////////////////////////////
@@ -868,35 +937,83 @@ struct Softmax<Hopper_qgmma_e4m3_fp32_traits, Kernel_traits>
868937 }
869938
870939 // Calculate max/sum, and update flash-attention scales.
940+ // Returns false if skipped due to skip-softmax attention feature.
871941 template <bool IS_FIRST_COL>
872- inline __device__ void compute_and_update_scale (
873- float (&global_max)[Mma_tile_p::CORES_M], float (&global_sum)[Mma_tile_p::CORES_M])
942+ inline __device__ bool compute_and_update_scale (
943+ float (&global_max)[Mma_tile_p::CORES_M], float (&global_sum)[Mma_tile_p::CORES_M], uint32_t* skip_softmax_vote )
874944 {
875945 float const scale = reinterpret_cast <float const &>(this ->scale_bmm1_ );
876946 float (&local_max_)[Mma_tile_p::CORES_M] = this ->local_max_ ;
877947 float (&local_sum_)[Mma_tile_p::CORES_M] = this ->local_sum_ ;
878948 float (&correction_)[Mma_tile_p::CORES_M] = this ->correction_ ;
879949 float (&elt_)[Mma_tile_p::CORES_M][Mma_tile_p::CORES_N * 2 ] = this ->elt_ ;
880950
951+ // whether this warpgroup skips the softmax
952+ constexpr bool may_skip = Kernel_traits::ENABLE_SKIP_SOFTMAX && !IS_FIRST_COL;
953+ bool skip = may_skip;
954+
881955// Row-wise max of current tile.
882956#pragma unroll
883957 for (int mi = 0 ; mi < Mma_tile_p::CORES_M; mi++)
884958 {
885- if (IS_FIRST_COL)
886- {
887- local_max_[mi] = elt_[mi][0 ];
888- }
889- else
890- {
891- local_max_[mi] = fmaxf (global_max[mi], elt_[mi][0 ]);
892- }
959+ local_max_[mi] = elt_[mi][0 ];
893960#pragma unroll
894961 for (int ni = 1 ; ni < Mma_tile_p::CORES_N * 2 ; ni++)
895962 {
896963 local_max_[mi] = fmaxf (local_max_[mi], elt_[mi][ni]);
897964 }
898965 local_max_[mi] = fmaxf (__shfl_xor_sync (uint32_t (-1 ), local_max_[mi], 1 ), local_max_[mi]);
899966 local_max_[mi] = fmaxf (__shfl_xor_sync (uint32_t (-1 ), local_max_[mi], 2 ), local_max_[mi]);
967+ // AND(&) the CORES_M results, then `skip` means whether to skip
968+ // the CORES_M(=2) rows
969+ if constexpr (may_skip)
970+ {
971+ // AND(&) the CORES_M results, then `skip` means whether to skip
972+ // the CORES_M(=2) rows
973+ if constexpr (!EXP2F_OPTIMIZATION)
974+ {
975+ skip &= expf (local_max_[mi] - global_max[mi]) < this ->skip_softmax_threshold ;
976+ }
977+ else
978+ {
979+ skip &= exp2f ((local_max_[mi] - global_max[mi]) * scale) < this ->skip_softmax_threshold ;
980+ }
981+ }
982+ if (!IS_FIRST_COL)
983+ {
984+ local_max_[mi] = fmaxf (local_max_[mi], global_max[mi]);
985+ }
986+ }
987+
988+ if constexpr (Kernel_traits::ENABLE_SKIP_SOFTMAX)
989+ {
990+ #ifdef SKIP_SOFTMAX_STAT
991+ this ->total_blocks ++;
992+ #endif
993+
994+ if constexpr (may_skip)
995+ {
996+ // AND(&) the results together in a warp, then `skip` means whether to skip
997+ // all the 16 rows managed by this warp.
998+ // each 4 threads (e.g. T0~T3) have the same `skip`, only 0x11111111 is needed
999+ // instead of 0xffffffff. But the perf is the same.
1000+ skip = __all_sync (0xffffffff , skip);
1001+ if (threadIdx.x % 32 == 0 )
1002+ {
1003+ // The leader of each warp votes.
1004+ atomicAnd (skip_softmax_vote, uint32_t (skip));
1005+ }
1006+ // WG0 uses 0x3 barrier, WG1 uses 0x4 barrier
1007+ named_barrier_wait (Base::SKIP_SOFTMAX_BARRIER + threadIdx.x / 128 , 128 );
1008+ skip = *((uint32_t volatile *) skip_softmax_vote);
1009+ if (skip)
1010+ {
1011+ #ifdef SKIP_SOFTMAX_STAT
1012+ this ->skipped_blocks ++;
1013+ #endif
1014+ return false ;
1015+ }
1016+ }
9001017 }
9011018
9021019// Softmax Exp.
@@ -987,6 +1104,7 @@ struct Softmax<Hopper_qgmma_e4m3_fp32_traits, Kernel_traits>
9871104 global_max[mi] = max_new;
9881105 }
9891106 }
1107+ return true ;
9901108 }
9911109
9921110 // Update flash attention scales and pack elements for BMM2.
0 commit comments