|
14 | 14 | # limitations under the License.
|
15 | 15 | """
|
16 | 16 |
|
| 17 | +import argparse |
17 | 18 | import json
|
18 | 19 | from dataclasses import asdict, dataclass
|
19 | 20 | from dataclasses import fields as dataclass_fields
|
@@ -190,7 +191,7 @@ class EngineArgs:
|
190 | 191 | """
|
191 | 192 | Flag to indicate whether to use warm-up before inference.
|
192 | 193 | """
|
193 |
| - enable_prefix_caching: bool = False |
| 194 | + enable_prefix_caching: bool = True |
194 | 195 | """
|
195 | 196 | Flag to enable prefix caching.
|
196 | 197 | """
|
@@ -387,6 +388,16 @@ def __post_init__(self):
|
387 | 388 | """
|
388 | 389 | if not self.tokenizer:
|
389 | 390 | self.tokenizer = self.model
|
| 391 | + if self.splitwise_role == "decode": |
| 392 | + self.enable_prefix_caching = False |
| 393 | + if self.speculative_config is not None: |
| 394 | + self.enable_prefix_caching = False |
| 395 | + if self.enable_mm: |
| 396 | + self.enable_prefix_caching = False |
| 397 | + if not current_platform.is_cuda(): |
| 398 | + self.enable_prefix_caching = False |
| 399 | + if self.dynamic_load_weight: |
| 400 | + self.enable_prefix_caching = False |
390 | 401 | if self.enable_logprob:
|
391 | 402 | if self.speculative_config is not None:
|
392 | 403 | raise NotImplementedError("Logprob does not support speculation_config.")
|
@@ -725,7 +736,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
725 | 736 | perf_group = parser.add_argument_group("Performance Tuning")
|
726 | 737 | perf_group.add_argument(
|
727 | 738 | "--enable-prefix-caching",
|
728 |
| - action="store_true", |
| 739 | + action=argparse.BooleanOptionalAction, |
729 | 740 | default=EngineArgs.enable_prefix_caching,
|
730 | 741 | help="Flag to enable prefix caching.",
|
731 | 742 | )
|
|
0 commit comments