
Commit 4c631db

Merge branch 'main' into lwilkinson/dbo-prefill
Signed-off-by: Tyler Michael Smith <[email protected]>
2 parents 356ddcb + b6a136b commit 4c631db


49 files changed (+2759, -973 lines)

docs/contributing/benchmarks.md

Lines changed: 2 additions & 2 deletions
@@ -680,7 +680,7 @@ vllm bench serve \
     --save-result \
     --result-dir ~/vllm_benchmark_results \
     --save-detailed \
-    --endpoint /v1/chat/completion
+    --endpoint /v1/chat/completions
 ```
 
 ##### Videos (ShareGPT4Video)
@@ -707,7 +707,7 @@ vllm bench serve \
     --save-result \
     --result-dir ~/vllm_benchmark_results \
     --save-detailed \
-    --endpoint /v1/chat/completion
+    --endpoint /v1/chat/completions
 ```
 
 ##### Synthetic Random Images (random-mm)
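The corrected path is the OpenAI-compatible chat endpoint that `vllm serve` exposes. As a minimal client-side sketch of that endpoint (assuming a server already running on `localhost:8000`; the model name is a placeholder):

```python
# Minimal sketch: call the /v1/chat/completions endpoint the benchmark targets.
# Assumes `vllm serve <model>` is already running on localhost:8000; the model
# name below is a placeholder for whatever model is being served.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="Qwen/Qwen2.5-VL-7B-Instruct",  # placeholder: use the served model
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
    max_tokens=32,
)
print(response.choices[0].message.content)
```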

docs/features/disagg_prefill.md

Lines changed: 6 additions & 0 deletions
@@ -31,6 +31,12 @@ Now supports 5 types of connectors:
     --kv-transfer-config '{"kv_connector":"MultiConnector","kv_role":"kv_both","kv_connector_extra_config":{"connectors":[{"kv_connector":"NixlConnector","kv_role":"kv_both"},{"kv_connector":"SharedStorageConnector","kv_role":"kv_both","kv_connector_extra_config":{"shared_storage_path":"local_storage"}}]}}'
     ```
 
+    For NixlConnector, you may also specify one or more NIXL backends, for example:
+
+    ```bash
+    --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both", "kv_buffer_device":"cuda", "kv_connector_extra_config":{"backend":["UCX", "GDS"]}}'
+    ```
+
 - **OffloadingConnector**: enable offloading of KV data to CPU memory, customizing the CPU block size (in tokens) and number of blocks to allocate (per worker):
 
     ```bash
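The same connector settings can also be passed through the offline Python API rather than the CLI flag. This is a sketch only, assuming `KVTransferConfig` accepts the fields shown (as in vLLM's disaggregated-prefill examples); the model name is a placeholder:

```python
# Sketch: the NixlConnector settings above, expressed via the offline Python API.
# Assumes vllm.config.KVTransferConfig accepts these fields; the model name is a
# placeholder.
from vllm import LLM
from vllm.config import KVTransferConfig

kv_config = KVTransferConfig(
    kv_connector="NixlConnector",
    kv_role="kv_both",
    kv_buffer_device="cuda",
    kv_connector_extra_config={"backend": ["UCX", "GDS"]},
)

llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
          kv_transfer_config=kv_config)
print(llm.generate(["Hello, my name is"])[0].outputs[0].text)
```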

docs/features/tool_calling.md

Lines changed: 9 additions & 0 deletions
@@ -319,6 +319,15 @@ Supported models:
 
 Flags: `--tool-call-parser glm45`
 
+### Qwen3-Coder Models (`qwen3_xml`)
+
+Supported models:
+
+* `Qwen/Qwen3-480B-A35B-Instruct`
+* `Qwen/Qwen3-Coder-30B-A3B-Instruct`
+
+Flags: `--tool-call-parser qwen3_xml`
+
 ### Models with Pythonic Tool Calls (`pythonic`)
 
 A growing number of models output a python list to represent tool calls instead of using JSON. This has the advantage of inherently supporting parallel tool calls and removing ambiguity around the JSON schema required for tool calls. The `pythonic` tool parser can support such models.
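As a minimal client-side sketch of exercising the new parser: it assumes a server started with `vllm serve Qwen/Qwen3-Coder-30B-A3B-Instruct --enable-auto-tool-choice --tool-call-parser qwen3_xml` on the default port, and the tool schema is illustrative only.

```python
# Client sketch for the qwen3_xml parser. Assumes the server was started with:
#   vllm serve Qwen/Qwen3-Coder-30B-A3B-Instruct \
#       --enable-auto-tool-choice --tool-call-parser qwen3_xml
# The tool schema below is illustrative only.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

tools = [{
    "type": "function",
    "function": {
        "name": "get_current_temperature",
        "description": "Get the current temperature for a location.",
        "parameters": {
            "type": "object",
            "properties": {"location": {"type": "string"}},
            "required": ["location"],
        },
    },
}]

response = client.chat.completions.create(
    model="Qwen/Qwen3-Coder-30B-A3B-Instruct",
    messages=[{"role": "user", "content": "What's the temperature in Paris?"}],
    tools=tools,
)
print(response.choices[0].message.tool_calls)
```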

docs/serving/expert_parallel_deployment.md

Lines changed: 1 addition & 1 deletion
@@ -193,7 +193,7 @@ For production deployments requiring strict SLA guarantees for time-to-first-tok
 
 1. **Install gdrcopy/ucx/nixl**: For maximum performance, run the [install_gdrcopy.sh](gh-file:tools/install_gdrcopy.sh) script to install `gdrcopy` (e.g., `install_gdrcopy.sh "${GDRCOPY_OS_VERSION}" "12.8" "x64"`). You can find available OS versions [here](https://developer.download.nvidia.com/compute/redist/gdrcopy/CUDA%2012.8/). If `gdrcopy` is not installed, things will still work with a plain `pip install nixl`, just with lower performance. `nixl` and `ucx` are installed as dependencies via pip.
 
-2. **Configure Both Instances**: Add this flag to both prefill and decode instances `--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}`
+2. **Configure Both Instances**: Add this flag to both prefill and decode instances: `--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}'`. Note that you may also specify one or more NIXL backends, for example: `--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both", "kv_connector_extra_config":{"backend":["UCX", "GDS"]}}'`
 
 3. **Client Orchestration**: Use the client-side script below to coordinate prefill/decode operations. We are actively working on routing solutions.
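Because the flag embeds JSON inside shell quoting, building the value programmatically avoids unbalanced braces and quotes. A small sketch (the `vllm serve` invocations in the comments are placeholders):

```python
# Sketch: build the --kv-transfer-config value programmatically so the JSON and
# the shell quoting stay balanced. The serve commands in the comments are
# placeholders.
import json
import shlex

config = {
    "kv_connector": "NixlConnector",
    "kv_role": "kv_both",
    "kv_connector_extra_config": {"backend": ["UCX", "GDS"]},
}

flag = "--kv-transfer-config " + shlex.quote(json.dumps(config))
# Pass the same flag to both instances, e.g.:
#   vllm serve <model> --port 8100 ... <flag>   # prefill
#   vllm serve <model> --port 8200 ... <flag>   # decode
print(flag)
```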

tests/entrypoints/openai/tool_parsers/test_hermes_tool_parser.py

Lines changed: 203 additions & 6 deletions
@@ -5,6 +5,11 @@
 
 import pytest
 
+from vllm.entrypoints.openai.protocol import ChatCompletionRequest
+from vllm.entrypoints.openai.tool_parsers.hermes_tool_parser import (
+    Hermes2ProToolParser)
+from vllm.transformers_utils.tokenizer import AnyTokenizer
+
 from ....utils import RemoteOpenAIServer
 
 MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
@@ -37,7 +42,7 @@
             },
             "unit": {
                 "type": "string",
-                "enum": ["celsius", "fahrenheit"]
+                "enum": ["celsius", "fahrenheit"],
             },
         },
         "required": ["location"],
@@ -75,7 +80,7 @@
     "user",
     "content":
     "Hi! Do you have any detailed information about the product id "
-    "7355608 and inserted true?"
+    "7355608 and inserted true?",
 }]
 
 
@@ -144,8 +149,8 @@ async def test_streaming_tool_call():
             if tool_chunk.function.name:
                 tool_call_chunks[index]["name"] += tool_chunk.function.name
             if tool_chunk.function.arguments:
-                tool_call_chunks[index][
-                    "arguments"] += tool_chunk.function.arguments
+                tool_call_chunks[index]["arguments"] += (
+                    tool_chunk.function.arguments)
 
     assert len(tool_call_chunks) == 1
     reconstructed_tool_call = tool_call_chunks[0]
@@ -234,8 +239,8 @@ async def test_streaming_product_tool_call():
             if tool_chunk.function.name:
                 tool_call_chunks[index]["name"] += tool_chunk.function.name
             if tool_chunk.function.arguments:
-                tool_call_chunks[index][
-                    "arguments"] += tool_chunk.function.arguments
+                tool_call_chunks[index]["arguments"] += (
+                    tool_chunk.function.arguments)
 
     assert len(tool_call_chunks) == 1
     reconstructed_tool_call = tool_call_chunks[0]
@@ -258,3 +263,195 @@ async def test_streaming_product_tool_call():
     print("\n[Streaming Product Test Passed]")
     print(f"Reconstructed Tool Call: {reconstructed_tool_call['name']}")
     print(f"Reconstructed Arguments: {arguments}")
+
+
+@pytest.fixture
+def qwen_tokenizer() -> AnyTokenizer:
+    from vllm.transformers_utils.tokenizer import get_tokenizer
+
+    return get_tokenizer("Qwen/Qwen3-32B")
+
+
+@pytest.fixture
+def hermes_parser(qwen_tokenizer: AnyTokenizer) -> Hermes2ProToolParser:
+    return Hermes2ProToolParser(qwen_tokenizer)
+
+
+@pytest.fixture
+def any_chat_request() -> ChatCompletionRequest:
+    return ChatCompletionRequest(
+        seed=42,
+        model="Qwen/Qwen3-32B",
+        messages=[],
+    )
+
+
+def test_hermes_parser_streaming_just_forward_text(
+    qwen_tokenizer: AnyTokenizer,
+    hermes_parser: Hermes2ProToolParser,
+    any_chat_request: ChatCompletionRequest,
+) -> None:
+    text = (
+        """This is some prior text that has nothing to do with tool calling."""
+    )
+    tokens = qwen_tokenizer.encode(text)
+    previous_text = ""
+    delta_messages = []
+    for token in tokens:
+        delta_text = qwen_tokenizer.decode([token])
+        current_text = previous_text + delta_text
+        delta = hermes_parser.extract_tool_calls_streaming(
+            previous_text=previous_text,
+            current_text=current_text,
+            delta_text=delta_text,
+            previous_token_ids=[],
+            current_token_ids=[],
+            delta_token_ids=[],
+            request=any_chat_request,
+        )
+        previous_text = current_text
+        delta_messages.append(delta)
+
+    for delta in delta_messages:
+        assert delta is not None
+        assert not delta.tool_calls
+
+    print(delta_messages)
+    assert "".join([delta.content for delta in delta_messages]) == text
+
+
+def test_hermes_parser_streaming_failure_case_bug_19056(
+    qwen_tokenizer: AnyTokenizer,
+    hermes_parser: Hermes2ProToolParser,
+    any_chat_request: ChatCompletionRequest,
+) -> None:
+    text = """<tool_call>
+{"name": "final_answer", "arguments": {"trigger": true}}
+</tool_call>"""
+    tokens = qwen_tokenizer.encode(text)
+    previous_text = ""
+    delta_messages = []
+    for token in tokens:
+        text = qwen_tokenizer.decode([token])
+        current_text = previous_text + text
+        delta = hermes_parser.extract_tool_calls_streaming(
+            previous_text=previous_text,
+            current_text=current_text,
+            delta_text=text,
+            previous_token_ids=[],
+            current_token_ids=[],
+            delta_token_ids=[],
+            request=any_chat_request,
+        )
+        previous_text = current_text
+        if delta is not None:
+            delta_messages.append(delta)
+
+    assert delta_messages[0].tool_calls[0].function.name == "final_answer"
+    tool_call_args = "".join(delta.tool_calls[0].function.arguments or ""
+                             for delta in delta_messages)
+    assert tool_call_args == '{"trigger": true}'
+
+
+def test_hermes_parser_streaming(
+    qwen_tokenizer: AnyTokenizer,
+    hermes_parser: Hermes2ProToolParser,
+    any_chat_request: ChatCompletionRequest,
+) -> None:
+    text = '<tool_call>\
+{"name": "get_current_temperature",\
+"arguments": {"location":\
+"San Francisco, California, United States", "unit": "celsius"}}\
+</tool_call>'
+
+    tokens = qwen_tokenizer.encode(text)
+    previous_text = ""
+    delta_messages = []
+    for token in tokens:
+        text = qwen_tokenizer.decode([token])
+        current_text = previous_text + text
+        delta = hermes_parser.extract_tool_calls_streaming(
+            previous_text=previous_text,
+            current_text=current_text,
+            delta_text=text,
+            previous_token_ids=[],
+            current_token_ids=[],
+            delta_token_ids=[],
+            request=any_chat_request,
+        )
+        previous_text = current_text
+        if delta is not None:
+            delta_messages.append(delta)
+    print(delta_messages)
+    assert (delta_messages[0].tool_calls[0].function.name ==
+            "get_current_temperature")
+    tool_call_args = "".join(delta.tool_calls[0].function.arguments or ""
+                             for delta in delta_messages)
+    assert tool_call_args == (
+        '{"location":"San Francisco, California, United States", '
+        '"unit": "celsius"}')
+
+
+def test_hermes_parser_non_streaming_no_tool_call(
+    hermes_parser: Hermes2ProToolParser,
+    any_chat_request: ChatCompletionRequest,
+) -> None:
+    text = """This is not a tool call."""
+    tool_call = hermes_parser.extract_tool_calls(
+        model_output=text,
+        request=any_chat_request,
+    )
+
+    assert tool_call is not None
+    assert not tool_call.tools_called
+
+
+def test_hermes_parser_non_streaming_tool_call_between_tags(
+    hermes_parser: Hermes2ProToolParser,
+    any_chat_request: ChatCompletionRequest,
+) -> None:
+    text = """<tool_call>
+{"name": "final_answer", "arguments": {"trigger": true}}
+</tool_call>"""
+    tool_call = hermes_parser.extract_tool_calls(
+        model_output=text,
+        request=any_chat_request,
+    )
+
+    assert tool_call is not None
+    assert tool_call.tools_called
+    assert tool_call.tool_calls[0].function.name == "final_answer"
+    assert tool_call.tool_calls[0].function.arguments == '{"trigger": true}'
+
+
+def test_hermes_parser_non_streaming_tool_call_until_eos(
+    hermes_parser: Hermes2ProToolParser,
+    any_chat_request: ChatCompletionRequest,
+) -> None:
+    text = """<tool_call>
+{"name": "final_answer", "arguments": {"trigger": true}}"""
+    tool_call = hermes_parser.extract_tool_calls(
+        model_output=text,
+        request=any_chat_request,
+    )
+
+    assert tool_call is not None
+    assert tool_call.tools_called
+    assert tool_call.tool_calls[0].function.name == "final_answer"
+    assert tool_call.tool_calls[0].function.arguments == '{"trigger": true}'
+
+
+def test_hermes_parser_non_streaming_tool_call_invalid_json(
+    hermes_parser: Hermes2ProToolParser,
+    any_chat_request: ChatCompletionRequest,
+) -> None:
+    # Missing closing brace to trigger exception
+    text = """<tool_call>
+{"name": "final_answer", "arguments": {"trigger": true}"""
+    tool_call = hermes_parser.extract_tool_calls(
+        model_output=text,
+        request=any_chat_request,
+    )
+
+    assert tool_call is not None
+    assert not tool_call.tools_called
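For readers who want to poke at the parser outside pytest, here is a standalone sketch that mirrors the non-streaming tests above, using the same classes they import (it downloads the `Qwen/Qwen3-32B` tokenizer on first use):

```python
# Standalone sketch mirroring the new non-streaming tests: run the Hermes parser
# directly, outside pytest. Downloads the Qwen/Qwen3-32B tokenizer on first use.
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.tool_parsers.hermes_tool_parser import (
    Hermes2ProToolParser)
from vllm.transformers_utils.tokenizer import get_tokenizer

tokenizer = get_tokenizer("Qwen/Qwen3-32B")
parser = Hermes2ProToolParser(tokenizer)
request = ChatCompletionRequest(model="Qwen/Qwen3-32B", messages=[])

model_output = ('<tool_call>\n'
                '{"name": "final_answer", "arguments": {"trigger": true}}\n'
                '</tool_call>')
result = parser.extract_tool_calls(model_output=model_output, request=request)
print(result.tools_called, result.tool_calls[0].function.arguments)
```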

tests/kernels/attention/test_attention_selector.py

Lines changed: 1 addition & 2 deletions
@@ -67,7 +67,6 @@ def generate_params()
     return params
 
 
-@pytest.mark.skip(reason="Skipped for now. Should be revisited.")
 @pytest.mark.parametrize("device, name, use_mla, block_size",
                          generate_params())
 def test_env(
@@ -189,7 +188,7 @@ def test_env(
             # FlashMLA only supports block_size == 64
             pytest.skip("FlashMLA only supports block_size 64")
         else:
-            from vllm.attention.backends.flashmla import (
+            from vllm.v1.attention.backends.mla.flashmla import (  # noqa: E501
                 is_flashmla_supported)
            is_supported, _ = is_flashmla_supported()
            if not is_supported:
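A small sketch of the same capability probe outside the test, using the v1 import path introduced above; treating the second tuple element as a human-readable reason is an assumption based on how the test unpacks the return value:

```python
# Sketch: probe FlashMLA support via the v1 import path used in the test.
# Treating the second tuple element as a human-readable reason is an assumption.
from vllm.v1.attention.backends.mla.flashmla import is_flashmla_supported

is_supported, reason = is_flashmla_supported()
if is_supported:
    print("FlashMLA is supported on this device (requires block_size 64).")
else:
    print(f"FlashMLA unavailable: {reason}")
```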

tests/lora/test_layers.py

Lines changed: 14 additions & 12 deletions
@@ -164,8 +164,8 @@ def populate_loras(
                 weight=layer_weights,
                 generate_embeddings_tensor=generate_embeddings_tensor,
             )
-            sublora.lora_b = sublora.lora_b[:, (sublora_len *
-                                                i):(sublora_len * (i + 1))]
+            sublora.lora_b = sublora.lora_b[(sublora_len *
+                                             i):(sublora_len * (i + 1)), :]
             sublora.optimize()
             subloras.append(sublora)
 
@@ -304,9 +304,9 @@ def create_random_embedding_layer():
             result = embedding(input_)
             after_a = F.embedding(
                 input_,
-                lora.lora_a,
+                lora.lora_a.T,
             )
-            result += (after_a @ lora.lora_b)
+            result += (after_a @ lora.lora_b.T)
             expected_results.append(result)
         expected_result = torch.cat(expected_results)
 
@@ -445,9 +445,9 @@ def create_random_embedding_layer():
             result = expanded_embedding(input_)
             after_a = F.embedding(
                 original_input_,
-                lora.lora_a,
+                lora.lora_a.T,
             )
-            result += (after_a @ lora.lora_b)
+            result += (after_a @ lora.lora_b.T)
             expected_results.append(result)
         expected_result = torch.cat(expected_results)
 
@@ -575,7 +575,7 @@ def _pretest():
                                              lm_head=linear,
                                              embedding_bias=None)
             result[:, vocab_size + embeddings_tensor_len:] = float("-inf")
-            result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
+            result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
             expected_results.append(result)
         expected_result = torch.cat(expected_results)
         logits_processor.org_vocab_size = vocab_size
@@ -692,9 +692,10 @@ def create_random_linear_replicated_layer():
 
         expected_results: list[torch.Tensor] = []
         for input_, lora_id in zip(inputs, prompt_mapping):
+
             lora = lora_dict[lora_id]
             result = linear(input_)[0]
-            result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
+            result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
             expected_results.append(result)
         expected_result = torch.cat(expected_results)
 
@@ -817,7 +818,7 @@ def create_random_linear_parallel_layer():
         for input_, lora_id in zip(inputs, prompt_mapping):
             lora = lora_dict[lora_id]
             result = linear(input_)[0]
-            result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
+            result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
             expected_results.append(result)
         expected_result = torch.cat(expected_results)
 
@@ -965,9 +966,10 @@ class FakeConfig:
             result = linear(input_)[0]
             subloras = sublora_dict[lora_id]
             for i, sublora in enumerate(subloras):
-                result[:, sublora.lora_b.shape[1] * i:sublora.lora_b.shape[1] *
-                       (i + 1)] += (input_ @ sublora.lora_a @ sublora.lora_b *
-                                    sublora.scaling)
+                result[:, sublora.lora_b.shape[0] * i:sublora.lora_b.shape[0] *
+                       (i + 1)] += (
+                           input_ @ sublora.lora_a.T @ sublora.lora_b.T *
+                           sublora.scaling)
             expected_results.append(result)
         expected_result = torch.cat(expected_results)
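A short sketch of the LoRA arithmetic these updated expectations encode, assuming `lora_a` is now stored as `[rank, in_features]` and `lora_b` as `[out_features, rank]` (the layout implied by the added transposes); the dimensions are arbitrary:

```python
# Sketch of the LoRA delta the updated expectations compute: with lora_a stored
# as [rank, in_features] and lora_b as [out_features, rank], the delta applied
# to the base output is x @ lora_a.T @ lora_b.T scaled by the LoRA scaling.
import torch

in_features, out_features, rank, scaling = 16, 32, 4, 0.5
x = torch.randn(2, in_features)
lora_a = torch.randn(rank, in_features)
lora_b = torch.randn(out_features, rank)

delta = x @ lora_a.T @ lora_b.T * scaling
print(delta.shape)  # torch.Size([2, 32])
```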
