
[New Feature] fa3 supports flash mask #3184


Merged
merged 4 commits into PaddlePaddle:develop from fa3
Aug 5, 2025

Conversation

@yangjianfengo1 (Contributor) commented Aug 4, 2025

This adds support for a mask on top of FlashAttention 3. The mask is an int array of shape [q_seq_len]: for the i-th token, every entry in row i of the qk gemm matrix from column mask[i] onward is masked out, i.e. qkgemm[i, mask[i]:] is masked. For example, to get a causal mask, the mask array is [1, 2, 3, 4, 5, ...].

If mask is None, a causal mask is used by default.
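
To make the semantics concrete, here is a minimal NumPy sketch of a reference attention that applies the mask as described above; the function name, tensor layout, and scaling are illustrative assumptions, not the actual kernel interface.

import numpy as np

def reference_attn_with_flash_mask(q, k, v, mask):
    # Illustrative sketch, not the PR's kernel. q, k, v: [num_heads, seq_len, head_dim];
    # mask: int array of shape [q_seq_len]. For token i, score columns mask[i]: are hidden,
    # i.e. scores[i, mask[i]:] is set to -inf before the softmax.
    scale = 1.0 / np.sqrt(q.shape[-1])
    scores = np.einsum("hqd,hkd->hqk", q, k) * scale
    for i, m in enumerate(mask):
        scores[:, i, m:] = -np.inf
    probs = np.exp(scores - scores.max(axis=-1, keepdims=True))
    probs /= probs.sum(axis=-1, keepdims=True)
    return np.einsum("hqk,hkd->hqd", probs, v)

# mask = [1, 2, 3, ...] keeps only columns 0..i in row i, i.e. an ordinary causal mask,
# which is also the behavior described above when mask is None.
causal_mask = np.arange(1, 8 + 1)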


paddle-bot (bot) commented Aug 4, 2025

Thanks for your contribution!

# Compare the naive reference attention against the fused FlashAttention-3 mask kernel.
naive_attn_out = naive_attn(q_input, k_input, v_input, mask)
paddle_attn_out = paddle_flash_attn_mask(q_input, k_input, v_input, mask)

# Print the maximum element-wise difference between the two outputs.
print((paddle_attn_out.reshape([-1]) - paddle.to_tensor(naive_attn_out).reshape([-1])).max())
Collaborator commented:

Use an assert here instead, so CI can monitor this unit test.
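
A minimal sketch of the suggested assertion, reusing the variables from the snippet above; the numpy-based comparison and the tolerances are assumptions, not the PR's actual test code.

import numpy as np

# Illustrative assertion so the unit test fails in CI when the kernel output drifts
# from the naive reference; the rtol/atol values here are assumptions.
np.testing.assert_allclose(
    paddle_attn_out.numpy().reshape([-1]),  # paddle_attn_out is assumed to be a paddle.Tensor
    np.asarray(naive_attn_out).reshape([-1]),
    rtol=1e-3,
    atol=1e-3,
)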

@yangjianfengo1 (Contributor, Author) replied:

Fixed.

@yuanlehome yuanlehome merged commit 40f7f3e into PaddlePaddle:develop Aug 5, 2025
12 of 14 checks passed
@yangjianfengo1 yangjianfengo1 deleted the fa3 branch August 6, 2025 04:09