Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 20 additions & 3 deletions docs/source/en/model_doc/diffusion_gemma.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,19 @@ The encoder operates in a prefill capacity, processing the initial prompt and ge

During inference, DiffusionGemma leverages multi-canvas sampling. Rather than generating one token at a time, the model iteratively denoises a full block of tokens using a diffusion sampler. Once a canvas is fully denoised, it is processed by the encoder and appended to the KV cache, after which the model generates the next canvas. This block-autoregressive approach facilitates text generation at higher speeds.

You can find the model card and checkpoint [here](https://huggingface.co/google/diffusiongemma-26B-A4B-it).
You can find the model card and checkpoint [here](https://huggingface.co/google/diffusiongemma-26B-A4B-it). You can find a visual guide to the model [here](https://newsletter.maartengrootendorst.com/p/a-visual-guide-to-diffusiongemma).

## Usage examples

Despite it being a text diffusion model and having a custom generation loop, most of the interface is shared with other model that can generate text with `.generate()`. If you're using another `transformers` model in your app, you should be able to directly replace it with this model.
Despite it being a text diffusion model and having a custom generation loop, most of the interface is shared with other models that can generate text with [`DiffusionGemmaGenerationMixin.generate`]. If you're using another `transformers` model in your app, you should be able to directly replace it with this model.

### Common caveats

Common caveats:
- DiffusionGemma doesn't accept `use_cache`. It always uses a KV cache;
- Support for common flags like `top_k` won't be available at release day, but will be added over time if they are compatible with text diffusion.

### Basic example

```python
from transformers import DiffusionGemmaForBlockDiffusion, AutoProcessor

Expand All @@ -62,6 +65,8 @@ inputs = processor.apply_chat_template(
return_dict=True,
return_tensors="pt",
add_generation_prompt=True,
# Add the following to enable thinking
# enable_thinking=True,
).to(model.device)
input_len = inputs["input_ids"].shape[-1]

Expand All @@ -71,6 +76,8 @@ output = model.generate(**inputs, max_new_tokens=256)
print(processor.decode(output.sequences[0][input_len:], skip_special_tokens=True))
```

### Streaming

Like other models that can generate text, you can set a streamer class to stream text. Unlike other models, DiffusionGemma generates intermediate drafts before the final text. You can visualize them with `TextDiffusionStreamer`

```python
Expand All @@ -81,6 +88,15 @@ streamer = TextDiffusionStreamer(tokenizer=processor.tokenizer)
model.generate(**inputs, max_new_tokens=256, streamer=streamer)
```

### Setting a starting denoising output

The model is trained to iteratively refine blocks of 256 tokens. On some applications, it may be beneficial to provide a starting point for the decoder, rather than starting from random tokens. You can use the `decoder_input_ids`, available on all model interfaces, to set the starting canvas.

```py
initial_estimate = ... # a tensor with shape (bsz, 256)
model.generate(**inputs, max_new_tokens=256, decoder_input_ids=initial_estimate)
```

## DiffusionGemmaTextConfig

[[autodoc]] DiffusionGemmaTextConfig
Expand All @@ -96,6 +112,7 @@ model.generate(**inputs, max_new_tokens=256, streamer=streamer)
## DiffusionGemmaGenerationMixin

[[autodoc]] DiffusionGemmaGenerationMixin
- generate

## DiffusionGemmaGenerationConfig

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
)
from ...generation.streamers import BaseStreamer
from ...modeling_outputs import ModelOutput
from ...utils import auto_docstring, logging
from ...utils import logging


logger = logging.get_logger(__name__)
Expand All @@ -55,7 +55,7 @@
class DiffusionGemmaGenerationConfig(GenerationConfig):
# no-format
"""
A GenerationConfig class with paremeterization custom to DiffusionGemma `generate`.
A GenerationConfig class with parameterization customized for [`DiffusionGemmaGenerationMixin.generate`].

Args:
> Parameters that control the length of the output
Expand Down Expand Up @@ -236,7 +236,6 @@ def from_model_config(self, *args, **kwargs):
raise NotImplementedError("DiffusionGemmaGenerationConfig does not support `from_model_config`")


@auto_docstring
@dataclass
class DiffusionGemmaGenerationOutput(ModelOutput):
"""
Expand Down Expand Up @@ -340,6 +339,7 @@ class EntropyBoundSampler:
renoises non-accepted tokens.

Here is a rough sketch of how the sampler loop works:
```
+-----------------------+
| Canvas initialization |
| x_T ∈ U(V) |
Expand All @@ -363,6 +363,7 @@ class EntropyBoundSampler:
| +-------------------------+
+---------------------------------------| Next canvas x_{t-1} |
+-------------------------+
```

Args:
config (`EntropyBoundSamplerConfig`):
Expand Down Expand Up @@ -553,6 +554,7 @@ def generate(

It contains an outer loop doing autoregressive generation of canvases (blocks of tokens), and an inner
loop doing diffusion on each canvas. The algorithm works roughly as follows:
```
1. Autoregressive canvas generation loop:
a. Encode all previous tokens using the encoder, to get the KV cache.
b. Prepare data for the new denoising loop
Expand All @@ -567,6 +569,7 @@ def generate(
e. Check if any autoregressive stopping criteria are met, and break the outer loop if all sequences have
met them. Replaces generated tokens in finished sequences by pad.
f. Prepare tensors for the next block
```

Parameters:
input_ids (*torch.LongTensor* of shape *(batch_size, sequence_length)*, *optional*):
Expand Down Expand Up @@ -598,9 +601,10 @@ def generate(
used with AR LLMs.
kwargs (`dict[str, Any]`, *optional*):
Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
forwarded to the `forward` function of the model.
forwarded to the `forward` function of the model. For instance, you can set the starting canvas with
`decoder_input_ids`.

Returns:
Return:
[`DiffusionGemmaGenerationOutput`]: a `ModelOutput` instance containing the generated text (`sequences`),
as well as other optional outputs.

Expand All @@ -610,11 +614,11 @@ def generate(
>>> from transformers import DiffusionGemmaForBlockDiffusion, AutoProcessor, TextDiffusionStreamer

>>> model = DiffusionGemmaForBlockDiffusion.from_pretrained(
... "CHECKPOINT", device_map="auto",
... "google/diffusiongemma-26B-A4B-it", device_map="auto",
>>> )

>>> chat = [{"role": "user", "content": "Why is the sky blue?"},]
>>> processor = AutoProcessor.from_pretrained("CHECKPOINT")
>>> processor = AutoProcessor.from_pretrained("google/diffusiongemma-26B-A4B-it")
>>> input_ids = processor.apply_chat_template(chat, tokenize=True, return_tensors="pt")

>>> streamer = TextDiffusionStreamer(tokenizer=processor.tokenizer)
Expand Down
Loading