Commit 6d6c6b0

[New Model]: google/embeddinggemma-300m (vllm-project#24318)

Adds embedding support for google/embeddinggemma-300m: registers the new
Gemma3TextModel pooling architecture, lets Gemma3 attention run encoder-only
(bidirectional) when the checkpoint requests it, and extends the
Sentence-Transformers projector loader to stack multiple Dense layers.

Signed-off-by: wang.yuqi <[email protected]>
1 parent 53b19cc
File tree

9 files changed: +73 −29 lines


docs/models/supported_models.md

Lines changed: 1 addition & 0 deletions

```diff
@@ -440,6 +440,7 @@ These models primarily support the [`LLM.embed`](./pooling_models.md#llmembed) API.
 |--------------|--------|-------------------|----------------------|---------------------------|---------------------|
 | `BertModel`<sup>C</sup> | BERT-based | `BAAI/bge-base-en-v1.5`, `Snowflake/snowflake-arctic-embed-xs`, etc. | | | ✅︎ |
 | `Gemma2Model`<sup>C</sup> | Gemma 2-based | `BAAI/bge-multilingual-gemma2`, etc. | ✅︎ | ✅︎ | ✅︎ |
+| `Gemma3TextModel`<sup>C</sup> | Gemma 3-based | `google/embeddinggemma-300m`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | ✅︎ |
 | `GteModel`<sup>C</sup> | Arctic-Embed-2.0-M | `Snowflake/snowflake-arctic-embed-m-v2.0`. | | | ✅︎ |
 | `GteNewModel`<sup>C</sup> | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-base`, etc. | | | ✅︎ |
```
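
With this row in place, the model is reachable through the offline embedding path the table describes. A minimal sketch of that usage, assuming a recent vLLM install; the `task="embed"` flag and the output layout follow vLLM's documented pooling API and may differ across versions, and the prompt is illustrative:

```python
# Minimal usage sketch (not part of this commit).
from vllm import LLM

llm = LLM(
    model="google/embeddinggemma-300m",
    task="embed",      # run as a pooling/embedding model
    dtype="bfloat16",  # float16 is rejected for gemma3_text (see vllm/config below)
)

outputs = llm.embed(["The chef prepared a delicious meal."])
print(len(outputs[0].outputs.embedding))  # embedding dimensionality
```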

tests/models/language/pooling/mteb_utils.py

Lines changed: 16 additions & 2 deletions

```diff
@@ -10,7 +10,8 @@
 import pytest
 import requests
 
-from tests.models.utils import EmbedModelInfo, RerankModelInfo
+from tests.models.utils import (EmbedModelInfo, RerankModelInfo,
+                                check_embeddings_close)
 
 # Most embedding models on the STS12 task (See #17175):
 # - Model implementation and minor changes in tensor dtype
@@ -163,12 +164,14 @@ def mteb_test_embed_models(hf_runner,
                            model_info: EmbedModelInfo,
                            vllm_extra_kwargs=None,
                            hf_model_callback=None,
-                           atol=MTEB_RERANK_TOL):
+                           atol=MTEB_EMBED_TOL):
     if not model_info.enable_test:
         # A model family has many models with the same architecture,
         # and we don't need to test each one.
         pytest.skip("Skipping test.")
 
+    example_prompts = ["The chef prepared a delicious meal."]
+
     vllm_extra_kwargs = vllm_extra_kwargs or {}
     vllm_extra_kwargs["dtype"] = model_info.dtype
 
@@ -191,6 +194,7 @@ def mteb_test_embed_models(hf_runner,
         vllm_main_score = run_mteb_embed_task(VllmMtebEncoder(vllm_model),
                                               MTEB_EMBED_TASKS)
         vllm_dtype = vllm_model.llm.llm_engine.model_config.dtype
+        vllm_outputs = vllm_model.embed(example_prompts)
 
     if model_info.mteb_score is None:
         with hf_runner(model_info.name,
@@ -202,6 +206,16 @@ def mteb_test_embed_models(hf_runner,
 
             st_main_score = run_mteb_embed_task(hf_model, MTEB_EMBED_TASKS)
             st_dtype = next(hf_model.model.parameters()).dtype
+
+            # Test embed_dims and whether to use normalize
+            hf_outputs = hf_model.encode(example_prompts)
+            check_embeddings_close(
+                embeddings_0_lst=hf_outputs,
+                embeddings_1_lst=vllm_outputs,
+                name_0="hf",
+                name_1="vllm",
+                tol=1e-2,
+            )
     else:
         st_main_score = model_info.mteb_score
         st_dtype = "Constant"
```
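
The new assertions run one example prompt through both Sentence-Transformers and vLLM and compare the raw vectors, so a mis-sized Dense projector or a skipped normalization step fails even when the MTEB score looks plausible. The real helper is `check_embeddings_close` in `tests/models/utils.py`; the stand-in below is a simplified sketch that assumes value-level comparison semantics:

```python
# Simplified stand-in for check_embeddings_close; assumptions noted above.
import torch

def embeddings_close(emb_a: list[float], emb_b: list[float],
                     tol: float = 1e-2) -> bool:
    # Dimensionality must match: catches a missing or mis-sized projector.
    if len(emb_a) != len(emb_b):
        return False
    # Value-level closeness (not just direction), so a vector normalized on
    # one side but not the other also fails.
    return torch.allclose(torch.tensor(emb_a), torch.tensor(emb_b), atol=tol)
```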

tests/models/language/pooling/test_st_projector.py

Lines changed: 6 additions & 1 deletion

```diff
@@ -2,7 +2,8 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import pytest
 
-from ...utils import CLSPoolingEmbedModelInfo, EmbedModelInfo
+from ...utils import (CLSPoolingEmbedModelInfo, EmbedModelInfo,
+                      LASTPoolingEmbedModelInfo)
 from .mteb_utils import mteb_test_embed_models
 
 # ST models with projector (Dense) layers
@@ -13,6 +14,10 @@
         mteb_score=0.688611955,
         enable_test=True,
     ),
+    LASTPoolingEmbedModelInfo("google/embeddinggemma-300m",
+                              architecture="Gemma3TextModel",
+                              mteb_score=0.7473819294684156,
+                              enable_test=True)
 ]
```
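
The new entry registers embeddinggemma under LAST pooling, unlike the CLS-pooled BERT-style model above it. Schematically, the two strategies pick different token positions from the final hidden states; the sketch below is illustrative only and is not vLLM's actual pooler implementation:

```python
import torch

def cls_pool(hidden_states: torch.Tensor) -> torch.Tensor:
    # CLS pooling: the first token's hidden state (BERT-style encoders).
    return hidden_states[0]

def last_pool(hidden_states: torch.Tensor) -> torch.Tensor:
    # LAST pooling: the final token's hidden state, the usual choice for
    # decoder-derived backbones such as Gemma.
    return hidden_states[-1]

states = torch.randn(7, 768)  # [seq_len, hidden_size]
assert cls_pool(states).shape == last_pool(states).shape == (768,)
```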

tests/models/registry.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -352,6 +352,7 @@ def check_available_online(
     # [Text-only]
     "BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"),
     "Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"),  # noqa: E501
+    "Gemma3TextModel": _HfExamplesInfo("google/embeddinggemma-300m"),
     "GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"),
     "GteModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-v2.0",
                                 trust_remote_code=True),
```

vllm/config/__init__.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -2750,6 +2750,8 @@ def compute_hash(self) -> str:
 _FLOAT16_NOT_SUPPORTED_MODELS = {
     "gemma2": "Numerical instability. Please use bfloat16 or float32 instead.",
     "gemma3": "Numerical instability. Please use bfloat16 or float32 instead.",
+    "gemma3_text":
+    "Numerical instability. Please use bfloat16 or float32 instead.",
     "plamo2": "Numerical instability. Please use bfloat16 or float32 instead.",
     "glm4": "Numerical instability. Please use bfloat16 or float32 instead.",
 }
```
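
`gemma3_text` (the model type the text-only embedding checkpoint reports) joins the model types for which float16 is refused. Schematically, the guard works as below; `reject_float16` is a hypothetical name used for illustration, the actual check lives elsewhere in `vllm/config`:

```python
# Hypothetical illustration of how the table is consulted.
def reject_float16(model_type: str, dtype: str) -> None:
    if dtype == "float16" and model_type in _FLOAT16_NOT_SUPPORTED_MODELS:
        reason = _FLOAT16_NOT_SUPPORTED_MODELS[model_type]
        raise ValueError(f"{model_type} does not support float16: {reason}")

# reject_float16("gemma3_text", "float16")  -> raises ValueError
# reject_float16("gemma3_text", "bfloat16") -> returns silently
```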

vllm/model_executor/models/adapters.py

Lines changed: 17 additions & 15 deletions

```diff
@@ -49,26 +49,28 @@ def _load_st_projector(model_config: "ModelConfig") -> Optional[nn.Module]:
         if not dense_modules:
             return None
 
-        module = dense_modules[0]
-        folder = module.get("path", "")
+        layers = []
+        for module in dense_modules:
+            folder = module.get("path", "")
+
+            config_path = f"{folder}/config.json" if folder else "config.json"
+            layer_config = get_hf_file_to_dict(config_path, model_config.model,
+                                               model_config.revision)
+            if not layer_config:
+                continue
 
-        config_path = f"{folder}/config.json" if folder else "config.json"
-        layer_config = get_hf_file_to_dict(config_path, model_config.model,
-                                           model_config.revision)
-        if not layer_config:
-            return None
+            linear = nn.Linear(layer_config.get("in_features", 768),
+                               layer_config.get("out_features", 768),
+                               bias=layer_config.get("bias", True),
+                               dtype=torch.float32)
 
-        linear = nn.Linear(layer_config.get("in_features", 768),
-                           layer_config.get("out_features", 768),
-                           bias=layer_config.get("bias", True),
-                           dtype=torch.float32)
+            if not _load_dense_weights(linear, folder, model_config):
+                continue
 
-        if _load_dense_weights(linear, folder, model_config):
-            layers = [linear]
+            layers.append(linear)
             if act_name := layer_config.get("activation_function"):
                 layers.append(get_act_fn(act_name))
-            return nn.Sequential(*layers).to(dtype=torch.float32)
-
+        return nn.Sequential(*layers).to(dtype=torch.float32)
     except Exception:
         logger.exception("ST projector loading failed")
```
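
Previously the loader materialized only the first Dense module listed in the checkpoint's Sentence-Transformers configuration; it now walks all of them, appending each Linear (plus any `activation_function` named in that layer's `config.json`) to a single `nn.Sequential`. That matters for checkpoints like embeddinggemma that ship a stack of Dense projections. A sketch of the resulting module, with illustrative 768 → 3072 → 768 shapes:

```python
import torch
import torch.nn as nn

# Sketch of the projector assembled for a two-Dense checkpoint; shapes are
# illustrative and are read in practice from each module's config.json.
projector = nn.Sequential(
    nn.Linear(768, 3072, bias=True, dtype=torch.float32),
    nn.Linear(3072, 768, bias=True, dtype=torch.float32),
)

pooled = torch.randn(1, 768)           # pooled backbone output
embedding = projector(pooled.float())  # final 768-dim sentence embedding
```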

vllm/model_executor/models/config.py

Lines changed: 9 additions & 0 deletions

```diff
@@ -24,6 +24,14 @@ def verify_and_update_config(vllm_config: "VllmConfig") -> None:
         raise NotImplementedError
 
 
+class Gemma3TextModelConfig:
+
+    @staticmethod
+    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
+        hf_config = vllm_config.model_config.hf_config
+        hf_config.is_causal = not hf_config.use_bidirectional_attention
+
+
 class GteNewModelConfig(VerifyAndUpdateConfig):
 
     @staticmethod
@@ -409,6 +417,7 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
     "GteModel": SnowflakeGteNewModelConfig,
     "GteNewModel": GteNewModelConfig,
     "GteNewForSequenceClassification": GteNewModelConfig,
+    "Gemma3TextModel": Gemma3TextModelConfig,
     "NomicBertModel": NomicBertModelConfig,
     "Qwen2ForProcessRewardModel": Qwen2ForProcessRewardModelConfig,
     "Qwen2ForRewardModel": Qwen2ForRewardModelConfig,
```

vllm/model_executor/models/gemma3.py

Lines changed: 20 additions & 11 deletions

```diff
@@ -24,7 +24,7 @@
 from torch import nn
 from transformers import Gemma3TextConfig
 
-from vllm.attention import Attention
+from vllm.attention import Attention, AttentionType
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@@ -44,6 +44,7 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
+from ...attention.layers.encoder_only_attention import EncoderOnlyAttention
 from .interfaces import SupportsLoRA, SupportsPP
 from .utils import (AutoWeightsLoader, extract_layer_index,
                     is_pp_missing_parameter,
@@ -169,16 +170,24 @@ def __init__(self,
             rope_scaling=self.rope_scaling,
         )
 
-        # Initialize the attention.
-        self.attn = Attention(self.num_heads,
-                              self.head_dim,
-                              self.scaling,
-                              num_kv_heads=self.num_kv_heads,
-                              cache_config=cache_config,
-                              quant_config=quant_config,
-                              logits_soft_cap=attn_logits_soft_cap,
-                              per_layer_sliding_window=sliding_window,
-                              prefix=f"{prefix}.attn")
+        if getattr(config, "is_causal", True):
+            attn_type = AttentionType.DECODER
+        else:
+            attn_type = AttentionType.ENCODER_ONLY
+
+        attn_cls = (EncoderOnlyAttention
+                    if attn_type == AttentionType.ENCODER_ONLY else Attention)
+
+        self.attn = attn_cls(self.num_heads,
+                             self.head_dim,
+                             self.scaling,
+                             num_kv_heads=self.num_kv_heads,
+                             cache_config=cache_config,
+                             quant_config=quant_config,
+                             attn_type=attn_type,
+                             logits_soft_cap=attn_logits_soft_cap,
+                             per_layer_sliding_window=sliding_window,
+                             prefix=f"{prefix}.attn")
 
     def forward(
         self,
```
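
`getattr(config, "is_causal", True)` keeps generation checkpoints on the causal decoder path by default; only configs that went through `Gemma3TextModelConfig` above can flip to `ENCODER_ONLY`, where every token attends to every token before pooling. The practical difference is the attention mask; the sketch below is illustrative only, since vLLM realizes this inside its attention backends rather than as dense boolean masks:

```python
import torch

def decoder_mask(n: int) -> torch.Tensor:
    # DECODER: token i attends to tokens 0..i (causal, lower-triangular).
    return torch.tril(torch.ones(n, n, dtype=torch.bool))

def encoder_only_mask(n: int) -> torch.Tensor:
    # ENCODER_ONLY: full bidirectional attention, what an embedding model
    # wants before pooling over the whole sequence.
    return torch.ones(n, n, dtype=torch.bool)
```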

vllm/model_executor/models/registry.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -155,6 +155,7 @@
     "BertModel": ("bert", "BertEmbeddingModel"),
     "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
     "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
+    "Gemma3TextModel": ("gemma3", "Gemma3Model"),
     "GlmForCausalLM": ("glm", "GlmForCausalLM"),
     "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
     "GritLM": ("gritlm", "GritLM"),
```
