[CUDA] Support array mask in SDPA #2822

Merged
zcbenz merged 1 commit into ml-explore:main from zcbenz:cudnn-sdpa-masks
Nov 26, 2025
Conversation

@zcbenz (Collaborator) commented Nov 23, 2025

Note that cuDNN does not support boolean masks, so we have to convert boolean masks to additive masks with `where(mask, full_like(mask, 0), full_like(mask, -inf))`, which has some performance penalty. (PyTorch does the same thing.)
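The boolean-to-additive conversion described above can be sketched in NumPy (a hedged illustration of the semantics; the actual implementation lives in MLX's C++ CUDA backend):

```python
import numpy as np

def bool_to_additive_mask(mask: np.ndarray) -> np.ndarray:
    # True -> 0.0 (position may be attended), False -> -inf (masked out),
    # mirroring where(mask, full_like(mask, 0), full_like(mask, -inf)).
    return np.where(mask,
                    np.full(mask.shape, 0.0, dtype=np.float32),
                    np.full(mask.shape, -np.inf, dtype=np.float32))

# Adding the mask to the scores before softmax drives masked
# positions to zero attention weight, since exp(-inf) == 0.
mask = np.array([[True, True, False]])
scores = np.zeros((1, 3), dtype=np.float32)
weights = np.exp(scores + bool_to_additive_mask(mask))
weights /= weights.sum(axis=-1, keepdims=True)  # [[0.5, 0.5, 0.0]]
```

The extra `where` is the source of the performance penalty mentioned above: it materializes a full float mask the same shape as the boolean one.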

What cuDNN does support is setting a padding mask directly: we pass the sequence lengths and cuDNN applies the padding masks automatically, and it works together with the set_causal_mask flag. I don't know how much performance gain this approach brings, but I think it is worth a try as future work.
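The padding-mask semantics can be illustrated with a small NumPy sketch (this shows the mask cuDNN would be equivalent to deriving internally, not cuDNN's actual API): positions at or beyond each sequence's length are masked out.

```python
import numpy as np

def padding_mask_from_lengths(seq_lens, max_len):
    # True for real tokens, False for padding positions. Passing just
    # the lengths avoids materializing this mask on the caller's side.
    positions = np.arange(max_len)
    return positions[None, :] < np.asarray(seq_lens)[:, None]

mask = padding_mask_from_lengths([2, 3], 4)
# [[ True,  True, False, False],
#  [ True,  True,  True, False]]
```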

@awni (Member) left a comment

Looks great!

As a general guideline we should work towards avoiding dispatching differently based on the back-end, because it breaks the ability to export from one machine to another (e.g. a graph exported on CUDA would not work on Metal). The fast primitives are one place where we break this guideline a lot (e.g. cpu vs gpu). Just elaborating for future reference, as we may want to push the mask -> float conversion into the primitive.

@zcbenz zcbenz merged commit 704fd1a into ml-explore:main Nov 26, 2025
10 checks passed
@zcbenz zcbenz deleted the cudnn-sdpa-masks branch November 26, 2025 02:08