@@ -12,6 +12,8 @@
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.v1.sample.metadata import SamplingMetadata
 
+PADDING_SLOT_ID = -1
+
 
 class EagleProposer:
 
@@ -23,6 +25,7 @@ def __init__(
         self.vllm_config = vllm_config
         self.num_speculative_tokens = (
             vllm_config.speculative_config.num_speculative_tokens)
+        self.max_model_len = vllm_config.model_config.max_model_len
         self.block_size = vllm_config.cache_config.block_size
         # We need +1 here because the arange is used to set query_start_loc,
         # which has one more element than batch_size.
@@ -112,22 +115,48 @@ def propose(
             # Update the inputs.
             input_ids = draft_token_ids_list[-1]
             positions += 1
+
+            # NOTE(woosuk): We should handle the case where the draft model
+            # generates tokens beyond the max model length. Since it is complex
+            # to remove such requests from the batch, we keep them in the batch
+            # but adjust the position ids and slot mappings to avoid the
+            # out-of-range access during the model execution. The draft tokens
+            # generated with this adjustment should be ignored.
+            exceeds_max_model_len = positions >= self.max_model_len
+            # Mask out the position ids that exceed the max model length.
+            # Otherwise, we may get out-of-range error in RoPE.
+            clamped_positions = torch.where(exceeds_max_model_len, 0,
+                                            positions)
+
+            # Increment the sequence lengths.
             attn_metadata.max_seq_len += 1
             attn_metadata.seq_lens += 1
+            # Consider max model length.
+            attn_metadata.max_seq_len = min(attn_metadata.max_seq_len,
+                                            self.max_model_len)
+            # For the requests that exceed the max model length, we set the
+            # sequence length to 1 to minimize their overheads in attention.
+            attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)
+
             # Compute the slot mapping.
-            block_numbers = positions // self.block_size
+            block_numbers = clamped_positions // self.block_size
             block_ids = block_table.gather(dim=1,
                                            index=block_numbers.view(-1, 1))
             block_ids = block_ids.view(-1)
             attn_metadata.slot_mapping = (block_ids * self.block_size +
-                                          positions % self.block_size)
+                                          clamped_positions % self.block_size)
+            # Mask out the slot mappings that exceed the max model length.
+            # Otherwise, the KV cache will be inadvertently updated with the
+            # padding tokens.
+            attn_metadata.slot_mapping.masked_fill_(exceeds_max_model_len,
+                                                    PADDING_SLOT_ID)
 
             # Run the model.
             with set_forward_context(attn_metadata, self.vllm_config):
                 hidden_states = self.model(
                     input_ids=input_ids,
                     hidden_states=hidden_states,
-                    positions=positions,
+                    positions=clamped_positions,
                 )
             logits = self.model.compute_logits(hidden_states, None)
             draft_token_ids, probs = compute_probs_and_sample_next_token(
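For reviewers who want to see the clamp-and-mask pattern in isolation: below is a minimal standalone sketch (toy tensor values, a made-up block table, and hypothetical sizes; not vLLM internals) of how clamping out-of-range positions and redirecting their slot mapping to a padding slot keeps both RoPE indexing and the KV-cache writes in bounds, mirroring what the hunk above does to `attn_metadata.slot_mapping`.

```python
# Standalone sketch of the clamp-and-mask pattern (toy sizes, made-up
# block table; not vLLM code).
import torch

PADDING_SLOT_ID = -1
max_model_len = 8
block_size = 4

# Per-request positions after a speculative step; request 1 has run past
# the max model length.
positions = torch.tensor([5, 9])
block_table = torch.tensor([[0, 1], [2, 3]])  # request -> physical block ids

exceeds = positions >= max_model_len
# Clamp to a valid position so RoPE indexing stays in range.
clamped = torch.where(exceeds, 0, positions)

block_numbers = clamped // block_size
block_ids = block_table.gather(dim=1, index=block_numbers.view(-1, 1)).view(-1)
slot_mapping = block_ids * block_size + clamped % block_size
# Redirect out-of-range requests to the padding slot so their (garbage)
# KV entries are never written into a real cache block.
slot_mapping = slot_mapping.masked_fill(exceeds, PADDING_SLOT_ID)

print(clamped.tolist())       # [5, 0]
print(slot_mapping.tolist())  # [5, -1]
```

The draft tokens produced for the masked requests are garbage by construction; as the NOTE in the diff says, they are kept in the batch only to avoid reshaping it mid-loop and must be ignored downstream.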