
Conversation

@Calvin-Xu
Member

This draft PR adds "flash" implementations of the recurrent and chunkwise gated delta rules, based on the Triton kernels from https://github.com/fla-org/flash-linear-attention. Notable improvements are:

  • 2D tiling on a (num_heads, num_V_tiles) grid, with the K dim tiled internally in a loop (see the sketch after this list)
  • kernel fusion, recomputation in the backward pass, and other optimizations from the official FLA version
  • support for varlen (ragged) inputs

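To make the layout concrete, here is a minimal, self-contained sketch of the tiling scheme, runnable in Pallas interpreter mode. It uses a toy, ungated linear-attention-style kernel rather than the actual gated delta rule kernels in this PR; the shapes, block sizes, and all names (toy_flash_linear_attention, _per_head_kernel, BLOCK_K, BLOCK_V) are hypothetical and only illustrate the (num_heads, num_V_tiles) grid with the K dim looped over inside the kernel.

```python
# Toy sketch (not the PR's kernel): 2D (num_heads, num_V_tiles) grid,
# K dim tiled by an in-kernel Python loop.  All names and sizes are hypothetical.
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

H, T, D_K, D_V = 4, 128, 64, 128   # heads, seq len, key dim, value dim
BLOCK_K, BLOCK_V = 32, 64          # K tile (looped in-kernel), V tile (grid axis)


def _per_head_kernel(q_ref, k_ref, v_ref, o_ref):
    # One program = one (head, V tile).  q/k blocks carry the full K dim;
    # v/o blocks carry a single V tile of size BLOCK_V.
    q = q_ref[...][0]              # (T, D_K)
    k = k_ref[...][0]              # (T, D_K)
    v = v_ref[...][0]              # (T, BLOCK_V)
    acc = jnp.zeros((T, BLOCK_V), dtype=jnp.float32)
    for i in range(D_K // BLOCK_K):            # K dim tiled internally in a loop
        sl = slice(i * BLOCK_K, (i + 1) * BLOCK_K)
        state = k[:, sl].T @ v                 # (BLOCK_K, BLOCK_V) partial state
        acc += q[:, sl] @ state                # this K tile's contribution
    o_ref[...] = acc[None].astype(o_ref.dtype)


def toy_flash_linear_attention(q, k, v, interpret=True):
    grid = (H, D_V // BLOCK_V)                 # 2D grid: (num_heads, num_V_tiles)
    return pl.pallas_call(
        _per_head_kernel,
        grid=grid,
        in_specs=[
            pl.BlockSpec(block_shape=(1, T, D_K), index_map=lambda h, vt: (h, 0, 0)),
            pl.BlockSpec(block_shape=(1, T, D_K), index_map=lambda h, vt: (h, 0, 0)),
            pl.BlockSpec(block_shape=(1, T, BLOCK_V), index_map=lambda h, vt: (h, 0, vt)),
        ],
        out_specs=pl.BlockSpec(block_shape=(1, T, BLOCK_V), index_map=lambda h, vt: (h, 0, vt)),
        out_shape=jax.ShapeDtypeStruct((H, T, D_V), q.dtype),
        interpret=interpret,                   # CPU interpreter mode, as in the tests
    )(q, k, v)
```

For this toy kernel, jnp.einsum("htk,hsk,hsv->htv", q, k, v) computes the same result, which makes for a quick sanity check on the tiling before moving to the real gated/chunked kernels.
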
Existing tests were parameterized with @pytest.mark.parametrize("use_flash", [True, False]) and all pass on CPU with pl.pallas_call(..., interpret=True) (Pallas interpreter mode). Additional work is in progress to (1) make the kernels work correctly on TPUs (https://docs.jax.dev/en/latest/pallas/tpu/index.html) and (2) refactor and clean up, in particular moving the flash code to a separate file.
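
For reference, a hypothetical test sketch in the same shape as that parametrization, reusing the toy kernel and constants (H, T, D_K, D_V) from the sketch above; the real tests exercise the actual gated delta rule entry points, not this stand-in:

```python
import jax
import jax.numpy as jnp
import pytest


def _toy_reference(q, k, v):
    # Dense einsum reference for the toy kernel above: o = q @ (k^T v) per head.
    return jnp.einsum("htk,hsk,hsv->htv", q, k, v)


@pytest.mark.parametrize("use_flash", [True, False])
def test_toy_kernel_matches_reference(use_flash):
    kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
    q = jax.random.normal(kq, (H, T, D_K), dtype=jnp.float32)
    k = jax.random.normal(kk, (H, T, D_K), dtype=jnp.float32)
    v = jax.random.normal(kv, (H, T, D_V), dtype=jnp.float32)

    if use_flash:
        out = toy_flash_linear_attention(q, k, v)  # Pallas path, interpret=True on CPU
    else:
        out = _toy_reference(q, k, v)              # plain einsum path
    assert jnp.allclose(out, _toy_reference(q, k, v), atol=1e-3, rtol=1e-3)
```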

@Calvin-Xu self-assigned this Nov 4, 2025
