
Commit 3a0fba5

[V1][Spec Decode] Handle draft tokens beyond max_model_len (vllm-project#16087)
Signed-off-by: Woosuk Kwon <[email protected]>
1 parent 299ebb6 commit 3a0fba5

File tree: 7 files changed (+137, -15 lines)

tests/v1/core/test_scheduler.py

Lines changed: 6 additions & 1 deletion
@@ -30,6 +30,7 @@ def create_scheduler(
     use_kv_connector: bool = False,
     num_blocks: int = 10000,
     block_size: int = 16,
+    max_model_len: Optional[int] = None,
 ) -> Scheduler:
     '''Create scheduler under test.
 
@@ -44,12 +45,15 @@ def create_scheduler(
     Returns:
       :class:`Scheduler` instance
     '''
+    if max_model_len is None:
+        max_model_len = max_num_batched_tokens
     scheduler_config = SchedulerConfig(
         max_num_seqs=max_num_seqs,
         max_num_batched_tokens=max_num_batched_tokens,
-        max_model_len=max_num_batched_tokens,
+        max_model_len=max_model_len,
         long_prefill_token_threshold=long_prefill_token_threshold,
         disable_chunked_mm_input=disable_chunked_mm_input,
+        enable_chunked_prefill=True,
     )
     model_config = ModelConfig(
         model=model,
@@ -296,6 +300,7 @@ def test_no_mm_input_chunking():
         model="llava-hf/llava-1.5-7b-hf",
         max_num_batched_tokens=1024,
         disable_chunked_mm_input=True,
+        max_model_len=2048,
     )
     mm_positions = [[PlaceholderRange(offset=400, length=800)]]
     requests = create_requests(num_requests=1,
Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Test whether spec decoding handles the max model length properly."""
+
+import pytest
+
+from vllm import LLM, SamplingParams
+
+_PROMPTS = [
+    "1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1",
+    "Repeat the following sentence 10 times: Consistency is key to mastering any skill.",  # noqa: E501
+    "Who won the Turing Award in 2018, and for what contribution? Describe in detail.",  # noqa: E501
+]
+
+
+@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 10])
+def test_ngram_max_len(
+    monkeypatch: pytest.MonkeyPatch,
+    num_speculative_tokens: int,
+):
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+
+        llm = LLM(
+            model="facebook/opt-125m",
+            max_model_len=100,
+            enforce_eager=True,  # For faster initialization.
+            speculative_config={
+                "method": "ngram",
+                "prompt_lookup_max": 5,
+                "prompt_lookup_min": 3,
+                "num_speculative_tokens": num_speculative_tokens,
+            },
+        )
+        sampling_params = SamplingParams(max_tokens=100, ignore_eos=True)
+        llm.generate(_PROMPTS, sampling_params)
+
+
+@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 10])
+def test_eagle_max_len(
+    monkeypatch: pytest.MonkeyPatch,
+    num_speculative_tokens: int,
+):
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+
+        llm = LLM(
+            model="meta-llama/Meta-Llama-3-8B-Instruct",
+            enforce_eager=True,  # For faster initialization.
+            speculative_config={
+                "method": "eagle",
+                "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",
+                "num_speculative_tokens": num_speculative_tokens,
+            },
+            max_model_len=100,
+        )
+        sampling_params = SamplingParams(max_tokens=100, ignore_eos=True)
+        llm.generate(_PROMPTS, sampling_params)

tests/v1/spec_decode/test_ngram.py

Lines changed: 19 additions & 9 deletions
@@ -2,7 +2,7 @@
 
 import numpy as np
 
-from vllm.config import SpeculativeConfig, VllmConfig
+from vllm.config import ModelConfig, SpeculativeConfig, VllmConfig
 from vllm.v1.spec_decode.ngram_proposer import (NgramProposer,
                                                 _find_subarray_kmp,
                                                 _kmp_lps_array)
@@ -42,14 +42,24 @@ def test_find_subarray_kmp():
 def test_ngram_proposer():
 
     def ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer:
-        return NgramProposer(vllm_config=VllmConfig(
-            speculative_config=SpeculativeConfig.from_dict(
-                {
-                    "prompt_lookup_min": min_n,
-                    "prompt_lookup_max": max_n,
-                    "num_speculative_tokens": k,
-                    "method": "ngram",
-                })))
+        # Dummy model config. Just to set max_model_len.
+        model_config = ModelConfig(model="facebook/opt-125m",
+                                   task="generate",
+                                   max_model_len=100,
+                                   tokenizer="facebook/opt-125m",
+                                   tokenizer_mode="auto",
+                                   dtype="auto",
+                                   seed=None,
+                                   trust_remote_code=False)
+        return NgramProposer(
+            vllm_config=VllmConfig(model_config=model_config,
+                                   speculative_config=SpeculativeConfig.
+                                   from_dict({
+                                       "prompt_lookup_min": min_n,
+                                       "prompt_lookup_max": max_n,
+                                       "num_speculative_tokens": k,
+                                       "method": "ngram",
+                                   })))
 
     # No match.
     result = ngram_proposer(

vllm/v1/core/sched/scheduler.py

Lines changed: 7 additions & 0 deletions
@@ -185,6 +185,13 @@ def schedule(self) -> SchedulerOutput:
             num_new_tokens = min(num_new_tokens, token_budget)
             assert num_new_tokens > 0
 
+            # Make sure the input position does not exceed the max model len.
+            # This is necessary when using spec decoding.
+            num_new_tokens = min(
+                num_new_tokens,
+                self.max_model_len - request.num_computed_tokens)
+            assert num_new_tokens > 0
+
             # Schedule encoder inputs.
             if request.has_encoder_inputs:
                 (encoder_inputs_to_schedule, num_new_tokens,
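The scheduler change above caps the number of newly scheduled tokens at whatever room remains before max_model_len. A minimal standalone sketch of that arithmetic (function and parameter names are illustrative, not vLLM's API):

def clamp_new_tokens(num_new_tokens: int, token_budget: int,
                     num_computed_tokens: int, max_model_len: int) -> int:
    # Respect the per-step token budget first.
    num_new_tokens = min(num_new_tokens, token_budget)
    # Then keep the input positions below max_model_len, which matters
    # when spec decoding appends draft tokens to a nearly-full request.
    num_new_tokens = min(num_new_tokens,
                         max_model_len - num_computed_tokens)
    assert num_new_tokens > 0
    return num_new_tokens

# Example: 98 of 100 positions are already computed, so at most
# 2 more tokens can be scheduled even with a large budget.
assert clamp_new_tokens(8, 1024, 98, 100) == 2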

vllm/v1/spec_decode/eagle.py

Lines changed: 32 additions & 3 deletions
@@ -12,6 +12,8 @@
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.v1.sample.metadata import SamplingMetadata
 
+PADDING_SLOT_ID = -1
+
 
 class EagleProposer:
 
@@ -23,6 +25,7 @@ def __init__(
         self.vllm_config = vllm_config
         self.num_speculative_tokens = (
            vllm_config.speculative_config.num_speculative_tokens)
+        self.max_model_len = vllm_config.model_config.max_model_len
         self.block_size = vllm_config.cache_config.block_size
         # We need +1 here because the arange is used to set query_start_loc,
         # which has one more element than batch_size.
@@ -112,22 +115,48 @@ def propose(
             # Update the inputs.
             input_ids = draft_token_ids_list[-1]
             positions += 1
+
+            # NOTE(woosuk): We should handle the case where the draft model
+            # generates tokens beyond the max model length. Since it is complex
+            # to remove such requests from the batch, we keep them in the batch
+            # but adjust the position ids and slot mappings to avoid the
+            # out-of-range access during the model execution. The draft tokens
+            # generated with this adjustment should be ignored.
+            exceeds_max_model_len = positions >= self.max_model_len
+            # Mask out the position ids that exceed the max model length.
+            # Otherwise, we may get out-of-range error in RoPE.
+            clamped_positions = torch.where(exceeds_max_model_len, 0,
+                                            positions)
+
+            # Increment the sequence lengths.
             attn_metadata.max_seq_len += 1
             attn_metadata.seq_lens += 1
+            # Consider max model length.
+            attn_metadata.max_seq_len = min(attn_metadata.max_seq_len,
+                                            self.max_model_len)
+            # For the requests that exceed the max model length, we set the
+            # sequence length to 1 to minimize their overheads in attention.
+            attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)
+
             # Compute the slot mapping.
-            block_numbers = positions // self.block_size
+            block_numbers = clamped_positions // self.block_size
             block_ids = block_table.gather(dim=1,
                                            index=block_numbers.view(-1, 1))
             block_ids = block_ids.view(-1)
             attn_metadata.slot_mapping = (block_ids * self.block_size +
-                                          positions % self.block_size)
+                                          clamped_positions % self.block_size)
+            # Mask out the slot mappings that exceed the max model length.
+            # Otherwise, the KV cache will be inadvertently updated with the
+            # padding tokens.
+            attn_metadata.slot_mapping.masked_fill_(exceeds_max_model_len,
+                                                    PADDING_SLOT_ID)
 
             # Run the model.
             with set_forward_context(attn_metadata, self.vllm_config):
                 hidden_states = self.model(
                     input_ids=input_ids,
                     hidden_states=hidden_states,
-                    positions=positions,
+                    positions=clamped_positions,
                 )
             logits = self.model.compute_logits(hidden_states, None)
             draft_token_ids, probs = compute_probs_and_sample_next_token(
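To illustrate the masking approach described in the NOTE above, here is a small self-contained PyTorch sketch (toy tensors, not the proposer's real state): positions past the limit are clamped to 0 so RoPE indexing stays in range, and their slots are redirected to a padding id so no KV-cache entry is written for them.

import torch

PADDING_SLOT_ID = -1
max_model_len = 8

positions = torch.tensor([6, 7, 8, 9])     # last two exceed max_model_len
slot_mapping = torch.tensor([6, 7, 8, 9])  # pretend one slot per position

exceeds_max_model_len = positions >= max_model_len
clamped_positions = torch.where(exceeds_max_model_len, 0, positions)
slot_mapping = slot_mapping.masked_fill(exceeds_max_model_len,
                                        PADDING_SLOT_ID)

print(clamped_positions)  # tensor([6, 7, 0, 0])
print(slot_mapping)       # tensor([ 6,  7, -1, -1])
# Draft tokens produced at the clamped positions are discarded afterwards.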

vllm/v1/spec_decode/ngram_proposer.py

Lines changed: 9 additions & 1 deletion
@@ -18,6 +18,9 @@ def __init__(self, vllm_config: VllmConfig):
         # tokens follow the match, we will return the maximum amount of
         # tokens until the end.
         self.k = vllm_config.speculative_config.num_speculative_tokens
+        # Maximum length of the model.
+        self.max_model_len = vllm_config.model_config.max_model_len
+
         # Trigger Numba JIT compilation for N-gram proposer.
         # This usually takes less than 1 second.
         self.propose(np.zeros(1024, dtype=np.int32))
@@ -50,9 +53,14 @@ def propose(
             followed that pattern. Here we will return [4,2,3] because
             we only have three tokens after the match.
         """
+        # Do not generate draft tokens beyond the max model length.
+        k = min(self.k, self.max_model_len - context_token_ids.shape[0])
+        if k <= 0:
+            return None
+
         # TODO(woosuk): Optimize this.
         for n in range(self.max_n, self.min_n - 1, -1):
-            result = _find_subarray_kmp(context_token_ids, n, self.k)
+            result = _find_subarray_kmp(context_token_ids, n, k)
             if result is not None:
                 return result
         return None
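The effect of the new cap in NgramProposer.propose, shown as a tiny sketch (a hypothetical helper, not vLLM code): the draft length k shrinks to the room left before max_model_len, and nothing is proposed once the context has already reached the limit.

from typing import Optional

def clamp_draft_len(k: int, context_len: int,
                    max_model_len: int) -> Optional[int]:
    # Never draft past the max model length.
    k = min(k, max_model_len - context_len)
    return None if k <= 0 else k

assert clamp_draft_len(5, 90, 100) == 5      # plenty of room
assert clamp_draft_len(5, 97, 100) == 3      # only 3 positions left
assert clamp_draft_len(5, 100, 100) is None  # no room: propose nothing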

vllm/v1/worker/gpu_model_runner.py

Lines changed: 7 additions & 1 deletion
@@ -1271,7 +1271,8 @@ def generate_draft_token_ids(
                 draft_token_ids.append([])
                 continue
 
-            # Skip requests that require top-p, top-k, etc.
+            # Skip requests that require sampling parameters that are not
+            # supported with speculative decoding.
             req_id = self.input_batch.req_ids[i]
             if not is_spec_decode_supported(req_id, self.input_batch):
                 draft_token_ids.append([])
@@ -1280,6 +1281,11 @@ def generate_draft_token_ids(
             # Add sampled_token_ids to token_ids_cpu.
             start_idx = self.input_batch.num_tokens_no_spec[i]
             end_idx = start_idx + num_sampled_ids
+            if end_idx >= self.max_model_len:
+                # Skip requests that have already reached the max model length.
+                draft_token_ids.append([])
+                continue
+
             self.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids
             drafter_output = self.drafter.propose(
                 self.input_batch.token_ids_cpu[i, :end_idx])
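A hedged sketch of the early-exit added above (illustrative names, not the runner's actual signature): once appending the newly sampled tokens would reach max_model_len, the request contributes an empty draft list instead of being handed to the drafter.

def should_skip_drafting(start_idx: int, num_sampled_ids: int,
                         max_model_len: int) -> bool:
    # end_idx is the request length after appending the sampled tokens.
    end_idx = start_idx + num_sampled_ids
    return end_idx >= max_model_len

assert not should_skip_drafting(50, 1, 100)  # still room for draft tokens
assert should_skip_drafting(99, 1, 100)      # at the limit: emit [] instead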
