@@ -24,7 +24,7 @@
 from torch import nn
 from transformers import Gemma3TextConfig
 
-from vllm.attention import Attention
+from vllm.attention import Attention, AttentionType
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@@ -44,6 +44,7 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
+from ...attention.layers.encoder_only_attention import EncoderOnlyAttention
 from .interfaces import SupportsLoRA, SupportsPP
 from .utils import (AutoWeightsLoader, extract_layer_index,
                     is_pp_missing_parameter,
@@ -169,16 +170,24 @@ def __init__(self,
             rope_scaling=self.rope_scaling,
         )
 
-        # Initialize the attention.
-        self.attn = Attention(self.num_heads,
-                              self.head_dim,
-                              self.scaling,
-                              num_kv_heads=self.num_kv_heads,
-                              cache_config=cache_config,
-                              quant_config=quant_config,
-                              logits_soft_cap=attn_logits_soft_cap,
-                              per_layer_sliding_window=sliding_window,
-                              prefix=f"{prefix}.attn")
+        if getattr(config, "is_causal", True):
+            attn_type = AttentionType.DECODER
+        else:
+            attn_type = AttentionType.ENCODER_ONLY
+
+        attn_cls = (EncoderOnlyAttention
+                    if attn_type == AttentionType.ENCODER_ONLY else Attention)
+
+        self.attn = attn_cls(self.num_heads,
+                             self.head_dim,
+                             self.scaling,
+                             num_kv_heads=self.num_kv_heads,
+                             cache_config=cache_config,
+                             quant_config=quant_config,
+                             attn_type=attn_type,
+                             logits_soft_cap=attn_logits_soft_cap,
+                             per_layer_sliding_window=sliding_window,
+                             prefix=f"{prefix}.attn")
 
     def forward(
         self,
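For readers skimming the hunk above, here is a minimal, self-contained sketch of the selection pattern it introduces. The `getattr(config, "is_causal", True)` default and the enum-based dispatch mirror the diff; the `AttentionType`, `Attention`, `EncoderOnlyAttention`, and `DummyConfig` classes below are placeholders for illustration only, not the vLLM implementations.

```python
# Sketch of the attention-type dispatch added in the diff above.
# All classes here are stand-ins; only the selection pattern is the point.
from dataclasses import dataclass
from enum import Enum, auto


class AttentionType(Enum):
    DECODER = auto()       # causal (masked) self-attention
    ENCODER_ONLY = auto()  # bidirectional self-attention


class Attention:  # placeholder for the regular attention layer
    def __init__(self, attn_type: AttentionType) -> None:
        self.attn_type = attn_type


class EncoderOnlyAttention(Attention):  # placeholder for the encoder-only layer
    pass


@dataclass
class DummyConfig:
    is_causal: bool = True


def build_attention(config) -> Attention:
    # Configs without an `is_causal` attribute default to causal decoder
    # attention, exactly like getattr(config, "is_causal", True) in the diff.
    if getattr(config, "is_causal", True):
        attn_type = AttentionType.DECODER
    else:
        attn_type = AttentionType.ENCODER_ONLY

    attn_cls = (EncoderOnlyAttention
                if attn_type == AttentionType.ENCODER_ONLY else Attention)
    return attn_cls(attn_type)


print(type(build_attention(DummyConfig())).__name__)                 # Attention
print(type(build_attention(DummyConfig(is_causal=False))).__name__)  # EncoderOnlyAttention
```

Defaulting to the decoder path keeps existing causal configurations behaving as before, while a config that explicitly sets `is_causal = False` is routed to the encoder-only (bidirectional) attention class.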