@@ -5,6 +5,7 @@
 
 import numpy as np
 import torch
+# yapf: disable
 from torch import nn
 from transformers import AutoModel, BatchFeature
 from transformers.models.gemma3n import (Gemma3nAudioConfig,
@@ -30,7 +31,6 @@
                                     MultiModalKwargsItems)
 from vllm.multimodal.parse import (ImageProcessorItems, MultiModalDataItems,
                                    MultiModalDataParser)
-# yapf: disable
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo,
                                         MultiModalPromptUpdates,
@@ -62,7 +62,8 @@ class Gemma3nImagePixelInputs(TypedDict):
 
 
 class Gemma3nAudioInputs(TypedDict):
-    input_features: torch.Tensor
+    input_features: Union[torch.Tensor, list[torch.Tensor]]
+    input_features_padded: torch.Tensor
     """Shape: `(batch_size * num_audio, seq_length, num_features)`"""
     input_features_mask: torch.Tensor
     """Shape: `(batch_size * num_audio, seq_length)`"""
@@ -188,8 +189,13 @@ def _call_hf_processor(
             mm_kwargs,
             tok_kwargs,
         )
+
         if 'input_features' in processed_outputs:
-            # Avoid padding since we need the output of each item to be
+            # Padding enables audio_tower to run in batched mode
+            processed_outputs["input_features_padded"] = \
+                processed_outputs["input_features"]
+
+            # Unpad features here since we need the output of each item to be
             # independent of other items for the cache to work correctly
             unpadded_features = [
                 f[mask] for f, mask in zip(
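
The reason both copies are kept (my reading of the comment above, shown with a hypothetical `pad_to` helper): a clip's padded tensor depends on the longest clip that happened to share its batch, so only the unpadded view is stable enough to cache per item, while the padded view is what lets `audio_tower` run one batched forward pass.

```python
import torch

clip = torch.randn(5, 3)  # one audio item: 5 real frames, 3 features

def pad_to(f: torch.Tensor, seq_len: int) -> torch.Tensor:
    """Hypothetical helper: right-pad with zeros to `seq_len` frames."""
    out = torch.zeros(seq_len, f.shape[1])
    out[:f.shape[0]] = f
    return out

# Padded to a batch max of 8 vs. 10 frames: different tensors...
assert not torch.equal(pad_to(clip, 8), pad_to(clip, 10))
# ...but identical once the padding mask is applied, which is what
# makes the unpadded features safe to cache per item.
assert torch.equal(pad_to(clip, 8)[torch.arange(8) < 5],
                   pad_to(clip, 10)[torch.arange(10) < 5])
```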
@@ -206,9 +212,11 @@ def _get_mm_fields_config(
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
 
-        return dict(pixel_values=MultiModalFieldConfig.batched("image"),
-                    input_features=MultiModalFieldConfig.batched("audio"),
-                    input_features_mask=MultiModalFieldConfig.batched("audio"))
+        return dict(
+            pixel_values=MultiModalFieldConfig.batched("image"),
+            input_features=MultiModalFieldConfig.batched("audio"),
+            input_features_padded=MultiModalFieldConfig.batched("audio"),
+            input_features_mask=MultiModalFieldConfig.batched("audio"))
 
     def _get_prompt_updates(
         self,
@@ -516,9 +524,14 @@ def _parse_and_validate_audio_input(
         if input_features_mask is None:
             return None
 
+        input_features_padded = kwargs.pop("input_features_padded", None)
+        if input_features_padded is None:
+            return None
+
         return Gemma3nAudioInputs(
             input_features=input_features,
             input_features_mask=input_features_mask,
+            input_features_padded=input_features_padded,
         )
 
     def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
@@ -564,7 +577,8 @@ def _process_audio_input(
         audio_input: Gemma3nAudioInputs,
     ) -> list[torch.Tensor]:
         assert self.audio_tower is not None
-        input_features = audio_input["input_features"].squeeze(1)
+        # Run on padded features to enable batching
+        input_features = audio_input["input_features_padded"].squeeze(1)
         input_features_mask = audio_input["input_features_mask"].squeeze(1)
         audio_outputs, audio_mask = self.audio_tower(input_features,
                                                      ~input_features_mask)
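
The `~` is worth noting: `input_features_mask` is True on real frames, while the tower expects a padding mask, hence the inversion. A toy stand-in for the encoder (the real `audio_tower` is Gemma3n's audio encoder and also downsamples; this sketch only shows the mask convention):

```python
import torch

def toy_audio_tower(feats, padding_mask):
    # Stand-in encoder: per-frame outputs plus a mask that is True
    # wherever the output corresponds to padding.
    return feats * 2.0, padding_mask

input_features = torch.randn(2, 4, 3)
input_features_mask = torch.tensor([[True, True, True, False],
                                    [True, True, False, False]])

audio_outputs, audio_mask = toy_audio_tower(input_features,
                                            ~input_features_mask)
# Keep only the non-padding outputs of each item, as a list of tensors.
per_item = [o[~m] for o, m in zip(audio_outputs, audio_mask)]
assert [t.shape[0] for t in per_item] == [3, 2]
```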