Merged
54 changes: 54 additions & 0 deletions docs/source/en/attention_interface.md
@@ -307,6 +307,60 @@ attention_mask = create_causal_mask(**mask_kwargs)

During generation, [`~GenerationMixin.generate`] builds masks through [`create_masks_for_generate`], which dispatches to the right `create_*_mask` based on the model config. Override it on a model class to plug in a custom masking strategy for generation.

## Pass a custom 4D attention mask

Pass your own 4D mask when you need an attention pattern the `create_*_mask` functions can't express. A 4D mask has the shape `(batch_size, 1, query_length, kv_length)`, where the `1` in the head dimension broadcasts the same mask across every attention head. When Transformers detects a 4D mask, it uses the mask as-is and skips `create_*_mask`.
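As a rough sketch of the shape requirement, the snippet below builds a 4D boolean mask from a 2D pattern and checks that the size-1 head dimension broadcasts against a tensor of attention scores. The tensor sizes and the "self plus previous token" pattern are illustrative assumptions, not something Transformers requires.

```py
import torch

batch_size, num_heads, query_length, kv_length = 2, 4, 5, 5  # illustrative sizes

# hypothetical pattern: each query attends to itself and the previous token only
q_idx = torch.arange(query_length)[:, None]
kv_idx = torch.arange(kv_length)[None, :]
pattern = (q_idx >= kv_idx) & (q_idx - kv_idx <= 1)

# add batch and head dimensions: (batch_size, 1, query_length, kv_length)
mask_4d = pattern[None, None].expand(batch_size, 1, query_length, kv_length)

# dummy attention scores with a full head dimension
scores = torch.zeros(batch_size, num_heads, query_length, kv_length)

# the size-1 head dimension broadcasts, so every head sees the same mask
masked_scores = scores.masked_fill(~mask_4d, float("-inf"))

print(mask_4d.shape)        # torch.Size([2, 1, 5, 5])
print(masked_scores.shape)  # torch.Size([2, 4, 5, 5])
```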

A 4D mask uses one of two value conventions.

| dtype | attend | mask out |
|---|---|---|
| boolean | `True` | `False` |
| float | `0.0` | `-inf` |

The float convention adds the mask to the attention scores before the softmax. Adding `0.0` leaves a score unchanged, so the position contributes. Adding `-inf` drives the score to `-inf`, which the softmax maps to a weight of zero, so the position is excluded.
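The additive behavior can be checked on a single row of scores. This is a minimal plain-PyTorch sketch with made-up score values, not code taken from Transformers.

```py
import torch

# raw attention scores for one query over four keys (made-up values)
scores = torch.tensor([2.0, 1.0, 0.5, 3.0])

# float mask: attend to the first two keys, mask out the last two
mask = torch.tensor([0.0, 0.0, float("-inf"), float("-inf")])

# the mask is added before the softmax
weights = torch.softmax(scores + mask, dim=-1)

print(weights)        # the last two weights are exactly 0
print(weights.sum())  # the remaining weights still sum to 1
```

The two masked keys receive no weight even though one of them had the highest raw score, and the unmasked keys are renormalized among themselves.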

> [!IMPORTANT]
> The accepted convention depends on the attention backend. `sdpa` takes a boolean or a float mask. `eager` adds the mask to the scores, so it takes a float mask only. `flash_attention_2` and `flex_attention` consume their own formats (a 2D padding mask and a [BlockMask](https://docs.pytorch.org/docs/stable/nn.attention.flex_attention.html#torch.nn.attention.flex_attention.BlockMask)) and don't accept a raw 4D mask.
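To illustrate the difference between the two conventions, the sketch below calls PyTorch's `torch.nn.functional.scaled_dot_product_attention` directly with a boolean mask and with the equivalent float mask, then compares both against a hand-written "eager-style" computation. The eager-style lines are an approximation of the add-then-softmax behavior described above, not the Transformers implementation, and the tensor sizes are arbitrary.

```py
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, heads, seq_len, head_dim = 1, 2, 4, 8  # illustrative sizes
q, k, v = (torch.randn(batch, heads, seq_len, head_dim) for _ in range(3))

# boolean convention: True attends, False masks out
bool_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))[None, None]

# float convention: 0.0 attends, -inf masks out
float_mask = torch.zeros(1, 1, seq_len, seq_len).masked_fill(~bool_mask, float("-inf"))

# SDPA accepts either convention
out_bool = F.scaled_dot_product_attention(q, k, v, attn_mask=bool_mask)
out_float = F.scaled_dot_product_attention(q, k, v, attn_mask=float_mask)

# eager-style attention: add the float mask to the scaled scores, then softmax
scores = q @ k.transpose(-2, -1) / math.sqrt(head_dim) + float_mask
out_eager = torch.softmax(scores, dim=-1) @ v

print(torch.allclose(out_bool, out_float, atol=1e-5))
print(torch.allclose(out_float, out_eager, atol=1e-5))
```

Adding the boolean mask in the eager-style line instead would add `1` or `0` to the scores rather than excluding anything, which is why a float mask is needed there.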

A common mistake is to reuse the `1`/`0` convention of a 2D padding mask in a float 4D mask. Because the mask is added to the scores, `0.0` keeps a position and `1.0` only adds a small bias.

The example below contrasts a wrong mask with a correct one. Both start from the same `1`/`0` causal pattern.

```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", attn_implementation="sdpa")
tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

input_ids = tokenizer("my favorite condiment on a", return_tensors="pt").input_ids
seq_len = input_ids.shape[1]

# 1 attends, 0 masks
causal = torch.tril(torch.ones(seq_len, seq_len))

# wrong: 1.0/0.0 floats are added to the scores, so 0.0 keeps a token and 1.0 barely changes it
wrong_mask = causal[None, None]

# correct: 0.0 attends, -inf masks
correct_mask = torch.where(causal.bool(), 0.0, float("-inf"))[None, None]
```

The wrong mask keeps every position. In the float convention, `0.0` is the value that attends and `1.0` only adds a small bias, so nothing is ever excluded.

```text
wrong_mask                    correct_mask
(1 attends, 0 masks)          (0 attends, -inf masks)

    k0  k1  k2  k3  k4              k0    k1    k2    k3    k4
q0   1   0   0   0   0        q0     0  -inf  -inf  -inf  -inf
q1   1   1   0   0   0        q1     0     0  -inf  -inf  -inf
q2   1   1   1   0   0        q2     0     0     0  -inf  -inf
q3   1   1   1   1   0        q3     0     0     0     0  -inf
q4   1   1   1   1   1        q4     0     0     0     0     0
```
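The effect of each mask can be checked without loading a model. The sketch below applies both masks to uniform dummy scores and inspects the resulting attention weights. The dummy scores are an assumption made to keep the example small; real scores come from the query-key product.

```py
import torch

seq_len = 5
causal = torch.tril(torch.ones(seq_len, seq_len))  # 1 attends, 0 masks

# wrong: 1.0/0.0 floats, correct: 0.0/-inf floats
wrong_mask = causal
correct_mask = torch.zeros(seq_len, seq_len).masked_fill(causal == 0, float("-inf"))

scores = torch.zeros(seq_len, seq_len)  # uniform dummy scores

wrong_weights = torch.softmax(scores + wrong_mask, dim=-1)
correct_weights = torch.softmax(scores + correct_mask, dim=-1)

# q0 should only attend to k0
print(wrong_weights[0])    # every key keeps a nonzero weight, including future ones
print(correct_weights[0])  # tensor([1., 0., 0., 0., 0.])
```

With the wrong mask, future positions are only slightly down-weighted relative to the allowed ones, so the model still reads tokens it should not see.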

## Bidirectional attention

Decoder-only models use causal (unidirectional) attention by default, where each token only attends to itself and previous tokens. Set `is_causal=False` to switch to bidirectional attention, where every token attends to every other token. This lets you use decoder-only models as text encoders, for example, to generate embeddings.