
Conversation

yaoyaoding (Member) commented on Sep 1, 2025

This PR adds an example implementation of the decoding kernel of flash linear attention:
https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/gated_delta_rule/fused_recurrent.py
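
For context, here is a minimal plain-PyTorch sketch of the gated delta rule recurrence that this decoding kernel computes. This is not the Tilus or Triton implementation; the function name, signature, and shape conventions are illustrative assumptions based on the benchmark shapes below.

```python
import torch

def gated_delta_rule_decode_ref(q, k, v, g, beta, scale=None):
    """Hypothetical reference for the gated delta rule recurrence, assuming:
      q, k : [B, T, H,  K]   queries / keys
      v    : [B, T, HV, V]   values
      g    : [B, T, HV]      per-step log decay (gate)
      beta : [B, T, HV]      per-step delta-rule step size
    with HV a multiple of H (grouped heads), as in the benchmark below."""
    B, T, H, K = q.shape
    HV, V = v.shape[2], v.shape[3]
    scale = K ** -0.5 if scale is None else scale
    group = HV // H
    S = q.new_zeros(B, HV, V, K)   # recurrent state: one [V, K] matrix per value head
    o = torch.empty(B, T, HV, V, dtype=q.dtype, device=q.device)
    for t in range(T):
        q_t = q[:, t].repeat_interleave(group, dim=1) * scale   # [B, HV, K]
        k_t = k[:, t].repeat_interleave(group, dim=1)           # [B, HV, K]
        v_t = v[:, t]                                           # [B, HV, V]
        S = S * g[:, t].exp()[..., None, None]                  # gated decay of the state
        v_err = v_t - torch.einsum('bhvk,bhk->bhv', S, k_t)     # delta-rule prediction error
        S = S + torch.einsum('bhv,bhk->bhvk', beta[:, t].unsqueeze(-1) * v_err, k_t)
        o[:, t] = torch.einsum('bhvk,bhk->bhv', S, q_t)         # read out with the query
    return o
```

The benchmark below compares the Triton and Tilus implementations of this kernel (latency in ms):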

| kernel | (B, T, H, K)     | (HV, V)  | latency (ms) |
|--------|------------------|----------|--------------|
| triton | (1, 1, 4, 128)   | (8, 128) | 0.007 |
| tilus  | (1, 1, 4, 128)   | (8, 128) | 0.006 |
| triton | (1, 2, 4, 128)   | (8, 128) | 0.008 |
| tilus  | (1, 2, 4, 128)   | (8, 128) | 0.007 |
| triton | (1, 4, 4, 128)   | (8, 128) | 0.011 |
| tilus  | (1, 4, 4, 128)   | (8, 128) | 0.009 |
| triton | (1, 8, 4, 128)   | (8, 128) | 0.016 |
| tilus  | (1, 8, 4, 128)   | (8, 128) | 0.013 |
| triton | (1, 16, 4, 128)  | (8, 128) | 0.026 |
| tilus  | (1, 16, 4, 128)  | (8, 128) | 0.022 |
| triton | (1, 32, 4, 128)  | (8, 128) | 0.047 |
| tilus  | (1, 32, 4, 128)  | (8, 128) | 0.044 |
| triton | (1, 64, 4, 128)  | (8, 128) | 0.081 |
| tilus  | (1, 64, 4, 128)  | (8, 128) | 0.086 |
| triton | (1, 128, 4, 128) | (8, 128) | 0.153 |
| tilus  | (1, 128, 4, 128) | (8, 128) | 0.161 |

To support this kernel, this PR also includes a few other enhancements:

  1. Add a `sqrt` operator.
  2. Move the layout inference verification so that the debug log provides better visualization.

Signed-off-by: Yaoyao Ding <[email protected]>
yaoyaoding changed the title from "[Example] Add example implementation for decoding kernel of flash linear attention" to "[Example] Add example for decoding kernel of flash linear attention" on Sep 1, 2025
yaoyaoding merged commit 87d6278 into main on Sep 1, 2025
9 of 10 checks passed