diff --git a/docs/source/en/model_doc/diffusion_gemma.md b/docs/source/en/model_doc/diffusion_gemma.md index 1f580b059adf..df0b57bc8464 100644 --- a/docs/source/en/model_doc/diffusion_gemma.md +++ b/docs/source/en/model_doc/diffusion_gemma.md @@ -29,16 +29,19 @@ The encoder operates in a prefill capacity, processing the initial prompt and ge During inference, DiffusionGemma leverages multi-canvas sampling. Rather than generating one token at a time, the model iteratively denoises a full block of tokens using a diffusion sampler. Once a canvas is fully denoised, it is processed by the encoder and appended to the KV cache, after which the model generates the next canvas. This block-autoregressive approach facilitates text generation at higher speeds. -You can find the model card and checkpoint [here](https://huggingface.co/google/diffusiongemma-26B-A4B-it). +You can find the model card and checkpoint [here](https://huggingface.co/google/diffusiongemma-26B-A4B-it). You can find a visual guide to the model [here](https://newsletter.maartengrootendorst.com/p/a-visual-guide-to-diffusiongemma). ## Usage examples -Despite it being a text diffusion model and having a custom generation loop, most of the interface is shared with other model that can generate text with `.generate()`. If you're using another `transformers` model in your app, you should be able to directly replace it with this model. +Despite it being a text diffusion model and having a custom generation loop, most of the interface is shared with other models that can generate text with [`DiffusionGemmaGenerationMixin.generate`]. If you're using another `transformers` model in your app, you should be able to directly replace it with this model. + +### Common caveats -Common caveats: - DiffusionGemma doesn't accept `use_cache`. It always uses a KV cache; - Support for common flags like `top_k` won't be available at release day, but will be added over time if they are compatible with text diffusion. 
+### Basic example + ```python from transformers import DiffusionGemmaForBlockDiffusion, AutoProcessor @@ -62,6 +65,8 @@ inputs = processor.apply_chat_template( return_dict=True, return_tensors="pt", add_generation_prompt=True, + # Add the following to enable thinking + # enable_thinking=True, ).to(model.device) input_len = inputs["input_ids"].shape[-1] @@ -71,6 +76,8 @@ output = model.generate(**inputs, max_new_tokens=256) print(processor.decode(output.sequences[0][input_len:], skip_special_tokens=True)) ``` +### Streaming + Like other models that can generate text, you can set a streamer class to stream text. Unlike other models, DiffusionGemma generates intermediate drafts before the final text. You can visualize them with `TextDiffusionStreamer` ```python @@ -81,6 +88,15 @@ streamer = TextDiffusionStreamer(tokenizer=processor.tokenizer) model.generate(**inputs, max_new_tokens=256, streamer=streamer) ``` +### Setting a starting denoising output + +The model is trained to iteratively refine blocks of 256 tokens. In some applications, it may be beneficial to provide a starting point for the decoder, rather than starting from random tokens. You can use the `decoder_input_ids`, available on all model interfaces, to set the starting canvas. + +```py +initial_estimate = ... 
# a tensor with shape (bsz, 256) +model.generate(**inputs, max_new_tokens=256, decoder_input_ids=initial_estimate) +``` + ## DiffusionGemmaTextConfig [[autodoc]] DiffusionGemmaTextConfig @@ -96,6 +112,7 @@ model.generate(**inputs, max_new_tokens=256, streamer=streamer) ## DiffusionGemmaGenerationMixin [[autodoc]] DiffusionGemmaGenerationMixin + - generate ## DiffusionGemmaGenerationConfig diff --git a/src/transformers/models/diffusion_gemma/generation_diffusion_gemma.py b/src/transformers/models/diffusion_gemma/generation_diffusion_gemma.py index 4672f301257b..edad72a5c6d9 100644 --- a/src/transformers/models/diffusion_gemma/generation_diffusion_gemma.py +++ b/src/transformers/models/diffusion_gemma/generation_diffusion_gemma.py @@ -45,7 +45,7 @@ ) from ...generation.streamers import BaseStreamer from ...modeling_outputs import ModelOutput -from ...utils import auto_docstring, logging +from ...utils import logging logger = logging.get_logger(__name__) @@ -55,7 +55,7 @@ class DiffusionGemmaGenerationConfig(GenerationConfig): # no-format """ - A GenerationConfig class with paremeterization custom to DiffusionGemma `generate`. + A GenerationConfig class with parameterization customized for [`DiffusionGemmaGenerationMixin.generate`]. Args: > Parameters that control the length of the output @@ -236,7 +236,6 @@ def from_model_config(self, *args, **kwargs): raise NotImplementedError("DiffusionGemmaGenerationConfig does not support `from_model_config`") -@auto_docstring @dataclass class DiffusionGemmaGenerationOutput(ModelOutput): """ @@ -340,6 +339,7 @@ class EntropyBoundSampler: renoises non-accepted tokens. 
Here is a rough sketch of how the sampler loop works: + ``` +-----------------------+ | Canvas initialization | | x_T ∈ U(V) | @@ -363,6 +363,7 @@ class EntropyBoundSampler: | +-------------------------+ +---------------------------------------| Next canvas x_{t-1} | +-------------------------+ + ``` Args: config (`EntropyBoundSamplerConfig`): @@ -553,6 +554,7 @@ def generate( It contains an outer loop doing autoregressive generation of canvases (blocks of tokens), and an inner loop doing diffusion on each canvas. The algorithm works roughly as follows: + ``` 1. Autoregressive canvas generation loop: a. Encode all previous tokens using the encoder, to get the KV cache. b. Prepare data for the new denoising loop @@ -567,6 +569,7 @@ def generate( e. Check if any autoregressive stopping criteria are met, and break the outer loop if all sequences have met them. Replaces generated tokens in finished sequences by pad. f. Prepare tensors for the next block + ``` Parameters: input_ids (*torch.LongTensor* of shape *(batch_size, sequence_length)*, *optional*): @@ -598,9 +601,10 @@ def generate( used with AR LLMs. kwargs (`dict[str, Any]`, *optional*): Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be - forwarded to the `forward` function of the model. + forwarded to the `forward` function of the model. For instance, you can set the starting canvas with + `decoder_input_ids`. - Returns: + Return: [`DiffusionGemmaGenerationOutput`]: a `ModelOutput` instance containing the generated text (`sequences`), as well as other optional outputs. @@ -610,11 +614,11 @@ def generate( >>> from transformers import DiffusionGemmaForBlockDiffusion, AutoProcessor, TextDiffusionStreamer >>> model = DiffusionGemmaForBlockDiffusion.from_pretrained( - ... "CHECKPOINT", device_map="auto", + ... 
"google/diffusiongemma-26B-A4B-it", device_map="auto", >>> ) >>> chat = [{"role": "user", "content": "Why is the sky blue?"},] - >>> processor = AutoProcessor.from_pretrained("CHECKPOINT") + >>> processor = AutoProcessor.from_pretrained("google/diffusiongemma-26B-A4B-it") >>> input_ids = processor.apply_chat_template(chat, tokenize=True, return_tensors="pt") >>> streamer = TextDiffusionStreamer(tokenizer=processor.tokenizer)