Commit 0a74e9d

[Gemma3n] Fix audio batching (vllm-project#24052)
Signed-off-by: NickLucche <[email protected]>
1 parent 8bd5844 commit 0a74e9d

2 files changed: 63 additions & 7 deletions

examples/online_serving/openai_chat_completion_client_for_multimodal.py

Lines changed: 42 additions & 0 deletions
@@ -266,10 +266,52 @@ def run_audio(model: str) -> None:
     print("Chat completion output from base64 encoded audio:", result)
 
 
+def run_multi_audio(model: str) -> None:
+    from vllm.assets.audio import AudioAsset
+
+    # Two different audios to showcase batched inference.
+    audio_url = AudioAsset("winning_call").url
+    audio_base64 = encode_base64_content_from_url(audio_url)
+    audio_url2 = AudioAsset("azacinto_foscolo").url
+    audio_base64_2 = encode_base64_content_from_url(audio_url2)
+
+    # OpenAI-compatible schema (`input_audio`)
+    chat_completion_from_base64 = client.chat.completions.create(
+        messages=[
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": "Are these two audios the same?"},
+                    {
+                        "type": "input_audio",
+                        "input_audio": {
+                            "data": audio_base64,
+                            "format": "wav",
+                        },
+                    },
+                    {
+                        "type": "input_audio",
+                        "input_audio": {
+                            "data": audio_base64_2,
+                            "format": "wav",
+                        },
+                    },
+                ],
+            }
+        ],
+        model=model,
+        max_completion_tokens=64,
+    )
+
+    result = chat_completion_from_base64.choices[0].message.content
+    print("Chat completion output from input audio:", result)
+
+
 example_function_map = {
     "text-only": run_text_only,
     "single-image": run_single_image,
     "multi-image": run_multi_image,
+    "multi-audio": run_multi_audio,
     "video": run_video,
     "audio": run_audio,
 }
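
The hunk above calls the script's existing encode_base64_content_from_url helper, which is defined elsewhere in the file and is not part of this diff. For context, a minimal sketch of what such a helper does (an illustration, not the script's actual implementation):

import base64

import requests


def encode_base64_content_from_url(content_url: str) -> str:
    # Illustrative sketch: fetch remote content and return it base64-encoded.
    response = requests.get(content_url)
    response.raise_for_status()
    return base64.b64encode(response.content).decode("utf-8")

With a Gemma3n-capable server running, the new "multi-audio" entry in example_function_map makes the two-clip scenario selectable through the script's existing chat-type dispatch.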

vllm/model_executor/models/gemma3n_mm.py

Lines changed: 21 additions & 7 deletions
@@ -5,6 +5,7 @@
 
 import numpy as np
 import torch
+# yapf: disable
 from torch import nn
 from transformers import AutoModel, BatchFeature
 from transformers.models.gemma3n import (Gemma3nAudioConfig,
@@ -30,7 +31,6 @@
                                     MultiModalKwargsItems)
 from vllm.multimodal.parse import (ImageProcessorItems, MultiModalDataItems,
                                    MultiModalDataParser)
-# yapf: disable
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo,
                                         MultiModalPromptUpdates,
@@ -62,7 +62,8 @@ class Gemma3nImagePixelInputs(TypedDict):
 
 
 class Gemma3nAudioInputs(TypedDict):
-    input_features: torch.Tensor
+    input_features: Union[torch.Tensor, list[torch.Tensor]]
+    input_features_padded: torch.Tensor
     """Shape: `(batch_size * num_audio, seq_length, num_features)`"""
     input_features_mask: torch.Tensor
     """Shape: `(batch_size * num_audio, seq_length)`"""
@@ -188,8 +189,13 @@ def _call_hf_processor(
             mm_kwargs,
             tok_kwargs,
         )
+
         if 'input_features' in processed_outputs:
-            # Avoid padding since we need the output of each item to be
+            # Padding enables audio_tower to run in batched mode
+            processed_outputs["input_features_padded"] = \
+                processed_outputs["input_features"]
+
+            # Unpad features here since we need the output of each item to be
             # independent of other items for the cache to work correctly
             unpadded_features = [
                 f[mask] for f, mask in zip(
@@ -206,9 +212,11 @@ def _get_mm_fields_config(
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
 
-        return dict(pixel_values=MultiModalFieldConfig.batched("image"),
-                    input_features=MultiModalFieldConfig.batched("audio"),
-                    input_features_mask=MultiModalFieldConfig.batched("audio"))
+        return dict(
+            pixel_values=MultiModalFieldConfig.batched("image"),
+            input_features=MultiModalFieldConfig.batched("audio"),
+            input_features_padded=MultiModalFieldConfig.batched("audio"),
+            input_features_mask=MultiModalFieldConfig.batched("audio"))
 
     def _get_prompt_updates(
         self,
@@ -516,9 +524,14 @@ def _parse_and_validate_audio_input(
         if input_features_mask is None:
             return None
 
+        input_features_padded = kwargs.pop("input_features_padded", None)
+        if input_features_padded is None:
+            return None
+
         return Gemma3nAudioInputs(
             input_features=input_features,
             input_features_mask=input_features_mask,
+            input_features_padded=input_features_padded,
         )
 
     def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
@@ -564,7 +577,8 @@ def _process_audio_input(
         audio_input: Gemma3nAudioInputs,
     ) -> list[torch.Tensor]:
         assert self.audio_tower is not None
-        input_features = audio_input["input_features"].squeeze(1)
+        # Run on padded features to enable batching
+        input_features = audio_input["input_features_padded"].squeeze(1)
         input_features_mask = audio_input["input_features_mask"].squeeze(1)
         audio_outputs, audio_mask = self.audio_tower(input_features,
                                                      ~input_features_mask)
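
A minimal sketch of the padding/unpadding split introduced above (shapes and values are illustrative, not taken from the model): the padded tensor lets audio_tower run a single batched forward pass, while the mask-selected per-item features keep each cached item independent of the rest of its batch.

import torch

# Two clips padded to the longest sequence (illustrative shapes).
batch, seq_len, num_features = 2, 6, 4
input_features_padded = torch.randn(batch, seq_len, num_features)

# True where a frame is real audio, False where it is padding.
input_features_mask = torch.tensor([
    [True, True, True, True, True, True],     # clip 1: 6 real frames
    [True, True, True, False, False, False],  # clip 2: 3 real frames
])

# Batched path: the audio tower consumes one padded tensor plus the mask.
assert input_features_padded.shape == (batch, seq_len, num_features)

# Per-item path: mask-selected, variable-length features, mirroring the
# `f[mask]` unpadding in _call_hf_processor above.
unpadded_features = [
    f[mask] for f, mask in zip(input_features_padded, input_features_mask)
]
assert unpadded_features[0].shape == (6, num_features)
assert unpadded_features[1].shape == (3, num_features)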
