
generate() produces incoherent output when inputs_embeds has length 1 #41863

@tyarkoni


System Info

  • transformers version: 4.57.1
  • Platform: Linux-6.8.0-51-generic-x86_64-with-glibc2.35
  • Python version: 3.10.18
  • Huggingface_hub version: 0.35.3
  • Safetensors version: 0.6.2
  • Accelerate version: 1.10.1
  • Accelerate config: not found
  • DeepSpeed version: not installed
  • PyTorch version (accelerator?): 2.8.0+cu128 (CUDA)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?:
  • Using GPU in script?:
  • GPU type: NVIDIA H100 PCIe

Who can help?

@xadupre (original author of code in question), @zucchini-nlp

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Steps to Reproduce

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load any causal LM
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model.eval()

# Create a single embedding (simulating a prefix like a style token)
single_embedding = torch.randn(1, 1, 768)  # [batch=1, seq_len=1, hidden_dim=768]

# Generate with length-1 inputs_embeds
with torch.no_grad():
    outputs = model.generate(
        inputs_embeds=single_embedding,
        max_length=20,
        do_sample=True,
        temperature=1.0,
        pad_token_id=tokenizer.eos_token_id,
    )

# Decode and observe gibberish
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print("Generated text:", generated_text)
# Output will be incoherent repetitive tokens like "the the the I if the..."

Comparison with working case (length ≥ 2):

# Add a second embedding (e.g., BOS token embedding)
bos_embedding = model.get_input_embeddings()(torch.tensor([[tokenizer.bos_token_id]]))
two_embeddings = torch.cat([single_embedding, bos_embedding], dim=1)  # [1, 2, 768]

# Generate with length-2 inputs_embeds
with torch.no_grad():
    outputs = model.generate(
        inputs_embeds=two_embeddings,
        max_length=20,
        do_sample=True,
        temperature=1.0,
        pad_token_id=tokenizer.eos_token_id,
    )

generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print("Generated text:", generated_text)
# Output will be coherent

Expected behavior

generate() should produce coherent text regardless of whether inputs_embeds has length 1 or length > 1, as long as the embeddings are valid.
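The behavior the report expects can be sketched with a manual greedy loop. The embedding prefix, of any length including 1, is run through the model once to fill the KV cache, and every later step feeds only the newest token ID. This is a minimal sketch, not a description of how generate() works internally. So that it runs offline, it assumes a tiny, randomly initialised GPT2LMHeadModel, which means the tokens it produces are meaningless. greedy_from_embeds is a hypothetical helper name.

```python
import torch
from transformers import GPT2Config, GPT2LMHeadModel

# ASSUMPTION: tiny random GPT-2 so the sketch runs without downloading weights;
# use AutoModelForCausalLM.from_pretrained("gpt2") to get real text.
torch.manual_seed(0)
config = GPT2Config(vocab_size=100, n_positions=64, n_embd=32, n_layer=2, n_head=2)
model = GPT2LMHeadModel(config).eval()


def greedy_from_embeds(model, inputs_embeds, max_new_tokens):
    """Hypothetical helper: consume the embeddings once, then continue from token IDs."""
    generated = []
    with torch.no_grad():
        # First forward pass: the whole embedding prefix fills the KV cache.
        out = model(inputs_embeds=inputs_embeds, use_cache=True)
        for step in range(max_new_tokens):
            next_id = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)  # [batch, 1]
            generated.append(next_id)
            if step + 1 < max_new_tokens:
                # Later passes: only the newest token ID; the cache holds the prefix.
                out = model(input_ids=next_id, past_key_values=out.past_key_values, use_cache=True)
    return torch.cat(generated, dim=1)  # [batch, max_new_tokens]


prefix = torch.randn(1, 1, config.n_embd)  # length-1 embedding prefix
print(greedy_from_embeds(model, prefix, max_new_tokens=5))
```

The embeddings are never passed again after the first forward pass, so each new token is conditioned on the prefix (via the cache) plus all previously generated tokens.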

Actual Behavior

With length-1 inputs_embeds, generate() produces incoherent, repetitive gibberish. The output appears to consist of high-frequency tokens, as if the model were not conditioning on the previously generated tokens.

Suggested Fix

The _cache_dependant_input_preparation method needs to properly handle the transition from inputs_embeds mode to input_ids mode after the first generation step. Specifically:

  1. After the first token is generated from inputs_embeds, set inputs_embeds = None for subsequent iterations
  2. Or, maintain proper bookkeeping so that the embeddings prefix is correctly tracked throughout the autoregressive loop
  3. Or, ensure Exception 4 logic properly handles the case where we've transitioned from embeddings to token IDs
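The following dependency-free simulation makes the suspected failure mode concrete. It models the per-step choice between feeding the prompt embeddings and feeding the newest token ID. It rests on an assumption that the report above does not state: that the current check reuses inputs_embeds whenever the number of positions processed in a step equals inputs_embeds.shape[1]. With a length-1 prefix that condition holds at every decoding step. Option 1 above instead consumes the embeddings on the first step only. All function names are hypothetical and are not part of transformers.

```python
# Hypothetical, simplified simulation of how a decoding loop might choose
# between the prompt embeddings and the newest token ID at each step.
# ASSUMPTION: the current check reuses inputs_embeds whenever the number of
# positions processed in a step equals the embedding prefix length.


def positions_in_step(step: int, prefix_len: int) -> int:
    # Step 0 processes the whole prefix; each later step processes one new token.
    return prefix_len if step == 0 else 1


def simulated_dispatch_current(step: int, prefix_len: int) -> str:
    # Length-based test: with prefix_len == 1 it is true at every step.
    return "embeds" if positions_in_step(step, prefix_len) == prefix_len else "ids"


def simulated_dispatch_fixed(step: int, prefix_len: int) -> str:
    # Option 1: use inputs_embeds on the first step only, then switch to IDs.
    return "embeds" if step == 0 else "ids"


for prefix_len in (1, 2):
    current = [simulated_dispatch_current(s, prefix_len) for s in range(4)]
    fixed = [simulated_dispatch_fixed(s, prefix_len) for s in range(4)]
    print(f"prefix_len={prefix_len}: current={current} fixed={fixed}")
# prefix_len=1: current=['embeds', 'embeds', 'embeds', 'embeds'] fixed=['embeds', 'ids', 'ids', 'ids']
# prefix_len=2: current=['embeds', 'ids', 'ids', 'ids'] fixed=['embeds', 'ids', 'ids', 'ids']
```

Under this assumption, the length-2 workaround in the reproduction avoids the problem only because a one-position decoding step never matches a two-position prefix.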

Labels: WIP, bug