Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 20 additions & 3 deletions src/transformers/models/janus/modeling_janus.py
Original file line number Diff line number Diff line change
Expand Up @@ -1285,7 +1285,7 @@ def generate(
input_ids, model_kwargs = self._expand_inputs_for_generation(
input_ids=input_ids,
attention_mask=attention_mask,
expand_size=generation_config.num_return_sequences,
expand_size=generation_config.num_return_sequences or 1,

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We just need to call `_prepare_generation_config` at the very beginning of `generate`; that call is currently missing.

**model_kwargs,
)

Expand All @@ -1298,6 +1298,17 @@ def generate(
attention_mask = attention_mask.repeat(2, 1)
model_kwargs["attention_mask"] = attention_mask

# Ensure generation_kwargs exists with boi_token_id

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably this and the rest of these checks will be resolved once `_prepare_generation_config` is called.

if not hasattr(generation_config, "generation_kwargs") or generation_config.generation_kwargs is None:
generation_config.generation_kwargs = {}
if "boi_token_id" not in generation_config.generation_kwargs:
# Default boi_token_id - usually the image_token_id from config
generation_config.generation_kwargs["boi_token_id"] = getattr(self.config, "image_token_id", 0)

# Ensure pad_token_id is set
if generation_config.pad_token_id is None:
generation_config.pad_token_id = getattr(self.config, "pad_token_id", 0)

# Mask all the tokens that are neither BOS nor BOI with pad token in the unconditional logits.
mask = (input_tokens[batch_size:, :] != generation_config.bos_token_id) & (
input_tokens[batch_size:, :] != generation_config.generation_kwargs["boi_token_id"]
Expand All @@ -1310,12 +1321,18 @@ def generate(

if model_kwargs.get("past_key_values", None) is None:
# Prepare cache if not provided.
# Need enough space for: input sequence + num_image_tokens iterations + safety margin
# The loop runs num_image_tokens times, starting from seq_len position
max_length = generation_config.max_length
min_cache_len = seq_len + num_image_tokens + 100 # Ensure enough buffer
if max_length is None:
max_length = min_cache_len
model_kwargs["past_key_values"] = self._prepare_static_cache(
cache_implementation=generation_config.cache_implementation or "static",
# batch_size should account for both conditional/unconditional input; hence multiplied by 2.
batch_size=batch_size * 2,
# we should have at least a cache len of seq_len + num_image_tokens.
max_cache_len=max(generation_config.max_length, num_image_tokens + seq_len),
# we should have at least a cache len of seq_len + num_image_tokens + buffer.
max_cache_len=max(max_length, min_cache_len),
model_kwargs=model_kwargs,
)

Expand Down
23 changes: 20 additions & 3 deletions src/transformers/models/janus/modular_janus.py
Original file line number Diff line number Diff line change
Expand Up @@ -1060,7 +1060,7 @@ def generate(
input_ids, model_kwargs = self._expand_inputs_for_generation(
input_ids=input_ids,
attention_mask=attention_mask,
expand_size=generation_config.num_return_sequences,
expand_size=generation_config.num_return_sequences or 1,
**model_kwargs,
)

Expand All @@ -1073,6 +1073,17 @@ def generate(
attention_mask = attention_mask.repeat(2, 1)
model_kwargs["attention_mask"] = attention_mask

# Ensure generation_kwargs exists with boi_token_id
if not hasattr(generation_config, "generation_kwargs") or generation_config.generation_kwargs is None:
generation_config.generation_kwargs = {}
if "boi_token_id" not in generation_config.generation_kwargs:
# Default boi_token_id - usually the image_token_id from config
generation_config.generation_kwargs["boi_token_id"] = getattr(self.config, "image_token_id", 0)

# Ensure pad_token_id is set
if generation_config.pad_token_id is None:
generation_config.pad_token_id = getattr(self.config, "pad_token_id", 0)

# Mask all the tokens that are neither BOS nor BOI with pad token in the unconditional logits.
mask = (input_tokens[batch_size:, :] != generation_config.bos_token_id) & (
input_tokens[batch_size:, :] != generation_config.generation_kwargs["boi_token_id"]
Expand All @@ -1085,12 +1096,18 @@ def generate(

if model_kwargs.get("past_key_values", None) is None:
# Prepare cache if not provided.
# Need enough space for: input sequence + num_image_tokens iterations + safety margin
# The loop runs num_image_tokens times, starting from seq_len position
max_length = generation_config.max_length
min_cache_len = seq_len + num_image_tokens + 100 # Ensure enough buffer
if max_length is None:
max_length = min_cache_len
model_kwargs["past_key_values"] = self._prepare_static_cache(
cache_implementation=generation_config.cache_implementation or "static",
# batch_size should account for both conditional/unconditional input; hence multiplied by 2.
batch_size=batch_size * 2,
# we should have at least a cache len of seq_len + num_image_tokens.
max_cache_len=max(generation_config.max_length, num_image_tokens + seq_len),
# we should have at least a cache len of seq_len + num_image_tokens + buffer.
max_cache_len=max(max_length, min_cache_len),
model_kwargs=model_kwargs,
)

Expand Down
Loading