From af75b30af1c70abe2027747cb0a5c0ad3a3fdfaa Mon Sep 17 00:00:00 2001 From: stevhliu Date: Mon, 8 Jun 2026 12:23:18 -0700 Subject: [PATCH 1/2] docs --- docs/source/en/attention_interface.md | 51 +++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/docs/source/en/attention_interface.md b/docs/source/en/attention_interface.md index aee2c2591098..f4de10c89a45 100644 --- a/docs/source/en/attention_interface.md +++ b/docs/source/en/attention_interface.md @@ -307,6 +307,57 @@ attention_mask = create_causal_mask(**mask_kwargs) During generation, [`~GenerationMixin.generate`] builds masks through [`create_masks_for_generate`], which dispatches to the right `create_*_mask` based on the model config. Override it on a model class to plug in a custom masking strategy for generation. +## Pass a custom 4D attention mask + +Pass your own 4D mask when you need an attention pattern the `create_*_mask` functions can't express. A 4D mask has the shape `(batch_size, 1, query_length, kv_length)`, which Transformers detects and uses the mask as-is, skipping `create_*_mask`. + +A 4D mask must use one of two value conventions. + +| dtype | attend | mask out | +|---|---|---| +| boolean | `True` | `False` | +| float | `0.0` | `-inf` | + +The float convention adds the mask to the attention scores before the softmax. A score plus `0.0` stays unchanged, so the position contributes. A score plus `-inf` drops to zero after the softmax, so the position is excluded. + +A common mistake is to reuse the `1`/`0` convention of a 2D padding mask in a float 4D mask. Because the mask is added to the scores, `0.0` keeps a position and `1.0` only adds a small bias. + +The example below contrasts a wrong mask with a correct one. Both start from the same `1`/`0` causal pattern. 
+ +```py +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0") +tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0") + +input_ids = tokenizer("my favorite condiment on a", return_tensors="pt").input_ids +seq_len = input_ids.shape[1] + +# 1 attends, 0 masks +causal = torch.tril(torch.ones(seq_len, seq_len)) + +# wrong: 1.0/0.0 floats are added to the scores, so 0.0 keeps a token and 1.0 barely changes it +wrong_mask = causal[None, None] + +# correct: 0.0 attends, -inf masks +correct_mask = torch.where(causal.bool(), 0.0, float("-inf"))[None, None] +``` + +The wrong mask keeps every position because, in the float convention, `0.0` is the value that attends and `1.0` only adds a small bias, so it never excludes anything. + +```text + wrong_mask correct_mask + (1 attends, 0 masks) (0 attends, -inf masks) + + k0 k1 k2 k3 k4 k0 k1 k2 k3 k4 + q0 1 0 0 0 0 q0 0 -inf -inf -inf -inf + q1 1 1 0 0 0 q1 0 0 -inf -inf -inf + q2 1 1 1 0 0 q2 0 0 0 -inf -inf + q3 1 1 1 1 0 q3 0 0 0 0 -inf + q4 1 1 1 1 1 q4 0 0 0 0 0 +``` + ## Bidirectional attention Decoder-only models use causal (unidirectional) attention by default, where each token only attends to itself and previous tokens. Set `is_causal=False` to switch to bidirectional attention, where every token attends to every other token. This lets you use decoder-only models as text encoders, for example, to generate embeddings. 
From cfb7b7b92df3381378bbc6a2b2cc24c3bd7ed8e8 Mon Sep 17 00:00:00 2001 From: stevhliu Date: Wed, 10 Jun 2026 08:46:40 -0700 Subject: [PATCH 2/2] feedback --- docs/source/en/attention_interface.md | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/docs/source/en/attention_interface.md b/docs/source/en/attention_interface.md index f4de10c89a45..af422a99f528 100644 --- a/docs/source/en/attention_interface.md +++ b/docs/source/en/attention_interface.md @@ -309,9 +309,9 @@ During generation, [`~GenerationMixin.generate`] builds masks through [`create_m ## Pass a custom 4D attention mask -Pass your own 4D mask when you need an attention pattern the `create_*_mask` functions can't express. A 4D mask has the shape `(batch_size, 1, query_length, kv_length)`, which Transformers detects and uses the mask as-is, skipping `create_*_mask`. +Pass your own 4D mask when you need an attention pattern the `create_*_mask` functions can't express. A 4D mask has the shape `(batch_size, 1, query_length, kv_length)`, where `1` broadcasts the same mask across every attention head. Transformers detects and uses the mask as-is, skipping `create_*_mask`. -A 4D mask must use one of two value conventions. +A 4D mask uses one of two value conventions. | dtype | attend | mask out | |---|---|---| @@ -320,6 +320,9 @@ A 4D mask must use one of two value conventions. The float convention adds the mask to the attention scores before the softmax. A score plus `0.0` stays unchanged, so the position contributes. A score plus `-inf` drops to zero after the softmax, so the position is excluded. +> [!IMPORTANT] +> The accepted convention depends on the attention backend. `sdpa` takes a boolean or a float mask. `eager` adds the mask to the scores, so it takes a float mask only. 
`flash_attention_2` and `flex_attention` consume their own formats (a 2D padding mask and a [BlockMask](https://docs.pytorch.org/docs/stable/nn.attention.flex_attention.html#torch.nn.attention.flex_attention.BlockMask)) and don't accept a raw 4D mask. + A common mistake is to reuse the `1`/`0` convention of a 2D padding mask in a float 4D mask. Because the mask is added to the scores, `0.0` keeps a position and `1.0` only adds a small bias. The example below contrasts a wrong mask with a correct one. Both start from the same `1`/`0` causal pattern. @@ -328,7 +331,7 @@ The example below contrasts a wrong mask with a correct one. Both start from the import torch from transformers import AutoModelForCausalLM, AutoTokenizer -model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0") +model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", attn_implementation="sdpa") tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0") input_ids = tokenizer("my favorite condiment on a", return_tensors="pt").input_ids