Commit fbe7417

Increase max gpu utilization for 70b models (#517)
* Increase max gpu utilization for 70b models
* Separate Gateway DTO and engine DTO
* Update test fixtures
1 parent c019a6a commit fbe7417

File tree

5 files changed: +127 additions, -45 deletions

model-engine/model_engine_server/common/dtos/llms.py

Lines changed: 24 additions & 0 deletions
@@ -535,6 +535,30 @@ class CreateBatchCompletionsRequest(BaseModel):
     """
 
 
+class CreateBatchCompletionsEngineRequest(CreateBatchCompletionsRequest):
+    """
+    Internal model for representing request to the llm engine. This contains additional fields that we want
+    hidden from the DTO exposed to the client.
+    """
+
+    max_gpu_memory_utilization: Optional[float] = Field(default=0.9, le=1.0)
+    """
+    Maximum GPU memory utilization for the batch inference. Default to 90%.
+    """
+
+    @staticmethod
+    def from_api(request: CreateBatchCompletionsRequest) -> "CreateBatchCompletionsEngineRequest":
+        return CreateBatchCompletionsEngineRequest(
+            input_data_path=request.input_data_path,
+            output_data_path=request.output_data_path,
+            content=request.content,
+            model_config=request.model_config,
+            data_parallelism=request.data_parallelism,
+            max_runtime_sec=request.max_runtime_sec,
+            tool_config=request.tool_config,
+        )
+
+
 class CreateBatchCompletionsResponse(BaseModel):
     job_id: str
 
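The point of the subclass split is that engine-only knobs never appear in the client-facing schema yet remain ordinary fields on the internal model. A minimal sketch of the pattern, assuming Pydantic v1 (the repo's parse_file/dict usage suggests v1); PublicRequest and EngineRequest are hypothetical stand-ins for the two DTOs:

from typing import Optional

from pydantic import BaseModel, Field


class PublicRequest(BaseModel):
    # Hypothetical stand-in for CreateBatchCompletionsRequest (client-facing DTO).
    model: str


class EngineRequest(PublicRequest):
    # Hypothetical stand-in for CreateBatchCompletionsEngineRequest.
    # Engine-only knob, validated to stay <= 1.0, defaulting to 90%.
    max_gpu_memory_utilization: Optional[float] = Field(default=0.9, le=1.0)


api_request = PublicRequest(model="llama-2-70b")
engine_request = EngineRequest(**api_request.dict())  # analogous to from_api

# The public schema never mentions the engine-only field.
assert "max_gpu_memory_utilization" not in PublicRequest.schema()["properties"]
assert engine_request.max_gpu_memory_utilization == 0.9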

model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py

Lines changed: 42 additions & 9 deletions
@@ -9,7 +9,7 @@
 import math
 import os
 import re
-from dataclasses import asdict
+from dataclasses import asdict, dataclass
 from typing import Any, AsyncIterable, Dict, List, Optional, Union
 
 from model_engine_server.common.config import hmi_config
@@ -21,6 +21,7 @@
     CompletionStreamV1Response,
     CompletionSyncV1Request,
     CompletionSyncV1Response,
+    CreateBatchCompletionsEngineRequest,
     CreateBatchCompletionsRequest,
     CreateBatchCompletionsResponse,
     CreateLLMModelEndpointV1Request,
@@ -2200,6 +2201,27 @@ async def execute(self, user: User, request: ModelDownloadRequest) -> ModelDownloadResponse:
         return ModelDownloadResponse(urls=urls)
 
 
+@dataclass
+class VLLMEngineArgs:
+    gpu_memory_utilization: Optional[float] = None
+
+
+def infer_addition_engine_args_from_model_name(model_name: str) -> VLLMEngineArgs:
+    numbers = re.findall(r"\d+", model_name)
+    if len(numbers) == 0:
+        raise ObjectHasInvalidValueException(
+            f"Model {model_name} is not supported for batch completions."
+        )
+
+    b_params = int(numbers[-1])
+    if b_params >= 70:
+        gpu_memory_utilization = 0.95
+    else:
+        gpu_memory_utilization = 0.9
+
+    return VLLMEngineArgs(gpu_memory_utilization=gpu_memory_utilization)
+
+
 def infer_hardware_from_model_name(model_name: str) -> CreateDockerImageBatchJobResourceRequests:
     if "mixtral-8x7b" in model_name:
         cpus = "20"
@@ -2324,14 +2346,25 @@ async def execute(
         assert hardware.gpus is not None
         if request.model_config.num_shards:
             hardware.gpus = max(hardware.gpus, request.model_config.num_shards)
-        request.model_config.num_shards = hardware.gpus
 
-        if request.tool_config and request.tool_config.name != "code_evaluator":
+        engine_request = CreateBatchCompletionsEngineRequest.from_api(request)
+        engine_request.model_config.num_shards = hardware.gpus
+
+        if engine_request.tool_config and engine_request.tool_config.name != "code_evaluator":
             raise ObjectHasInvalidValueException(
                 "Only code_evaluator tool is supported for batch completions."
             )
 
-        batch_bundle = await self.create_batch_job_bundle(user, request, hardware)
+        additional_engine_args = infer_addition_engine_args_from_model_name(
+            engine_request.model_config.model
+        )
+
+        if additional_engine_args.gpu_memory_utilization is not None:
+            engine_request.max_gpu_memory_utilization = (
+                additional_engine_args.gpu_memory_utilization
+            )
+
+        batch_bundle = await self.create_batch_job_bundle(user, engine_request, hardware)
 
         validate_resource_requests(
             bundle=batch_bundle,
@@ -2342,21 +2375,21 @@ async def execute(
             gpu_type=hardware.gpu_type,
         )
 
-        if request.max_runtime_sec is None or request.max_runtime_sec < 1:
+        if engine_request.max_runtime_sec is None or engine_request.max_runtime_sec < 1:
             raise ObjectHasInvalidValueException("max_runtime_sec must be a positive integer.")
 
         job_id = await self.docker_image_batch_job_gateway.create_docker_image_batch_job(
             created_by=user.user_id,
             owner=user.team_id,
-            job_config=request.dict(),
+            job_config=engine_request.dict(),
             env=batch_bundle.env,
             command=batch_bundle.command,
             repo=batch_bundle.image_repository,
             tag=batch_bundle.image_tag,
             resource_requests=hardware,
-            labels=request.model_config.labels,
+            labels=engine_request.model_config.labels,
             mount_location=batch_bundle.mount_location,
-            override_job_max_runtime_s=request.max_runtime_sec,
-            num_workers=request.data_parallelism,
+            override_job_max_runtime_s=engine_request.max_runtime_sec,
+            num_workers=engine_request.data_parallelism,
         )
         return CreateBatchCompletionsResponse(job_id=job_id)
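infer_addition_engine_args_from_model_name reads the last number in the model name as the parameter count in billions and bumps utilization to 0.95 at 70B and above. A standalone sketch of the same heuristic, with a plain ValueError standing in for ObjectHasInvalidValueException:

import re
from dataclasses import dataclass
from typing import Optional


@dataclass
class EngineArgs:  # stand-in for VLLMEngineArgs
    gpu_memory_utilization: Optional[float] = None


def infer_engine_args(model_name: str) -> EngineArgs:
    numbers = re.findall(r"\d+", model_name)
    if not numbers:
        raise ValueError(f"Model {model_name} is not supported for batch completions.")
    b_params = int(numbers[-1])  # last number, read as billions of parameters
    return EngineArgs(gpu_memory_utilization=0.95 if b_params >= 70 else 0.9)


assert infer_engine_args("llama-2-70b").gpu_memory_utilization == 0.95
assert infer_engine_args("mistral-7b").gpu_memory_utilization == 0.9
assert infer_engine_args("mixtral-8x7b").gpu_memory_utilization == 0.9  # last number is 7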

model-engine/model_engine_server/inference/batch_inference/vllm_batch.py

Lines changed: 4 additions & 4 deletions
@@ -15,7 +15,7 @@
 from func_timeout import FunctionTimedOut, func_set_timeout
 from model_engine_server.common.dtos.llms import (
     CompletionOutput,
-    CreateBatchCompletionsRequest,
+    CreateBatchCompletionsEngineRequest,
     CreateBatchCompletionsRequestContent,
     TokenOutput,
     ToolConfig,
@@ -145,7 +145,7 @@ def random_uuid() -> str:
     return str(uuid.uuid4().hex)
 
 
-def get_vllm_engine(model, request):
+def get_vllm_engine(model: str, request: CreateBatchCompletionsEngineRequest):
     from vllm import AsyncEngineArgs, AsyncLLMEngine
 
     engine_args = AsyncEngineArgs(
@@ -154,7 +154,7 @@ def get_vllm_engine(model, request):
         tensor_parallel_size=request.model_config.num_shards,
         seed=request.model_config.seed or 0,
         disable_log_requests=True,
-        gpu_memory_utilization=0.9,
+        gpu_memory_utilization=request.max_gpu_memory_utilization or 0.9,
     )
 
     llm = AsyncLLMEngine.from_engine_args(engine_args)
@@ -313,7 +313,7 @@ def tool_func(text: str, past_context: Optional[str]):
 async def batch_inference():
     job_index = int(os.getenv("JOB_COMPLETION_INDEX", 0))
 
-    request = CreateBatchCompletionsRequest.parse_file(CONFIG_FILE)
+    request = CreateBatchCompletionsEngineRequest.parse_file(CONFIG_FILE)
 
     if request.model_config.checkpoint_path is not None:
         download_model(request.model_config.checkpoint_path, MODEL_WEIGHTS_FOLDER)
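Because the gateway writes engine_request.dict() into the job config and the worker now parses that file with the engine DTO, the utilization value survives the hop. A minimal sketch of that round trip, again assuming Pydantic v1, with a hypothetical EngineRequest model in place of CreateBatchCompletionsEngineRequest and a temp file in place of CONFIG_FILE:

import json
import tempfile
from typing import Optional

from pydantic import BaseModel, Field


class EngineRequest(BaseModel):  # hypothetical stand-in for the engine DTO
    model: str
    max_gpu_memory_utilization: Optional[float] = Field(default=0.9, le=1.0)


# Gateway side: the engine request is serialized into the batch job's config.
job_config = EngineRequest(model="llama-2-70b", max_gpu_memory_utilization=0.95).dict()

with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump(job_config, f)

# Worker side: parse the config back into the engine DTO (as vllm_batch.py does).
request = EngineRequest.parse_file(f.name)
assert (request.max_gpu_memory_utilization or 0.9) == 0.95  # the value handed to vLLM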

model-engine/tests/unit/inference/conftest.py

Lines changed: 12 additions & 5 deletions
@@ -3,6 +3,7 @@
 import pytest
 from model_engine_server.common.dtos.llms import (
     CompletionOutput,
+    CreateBatchCompletionsEngineRequest,
     CreateBatchCompletionsModelConfig,
     CreateBatchCompletionsRequest,
     CreateBatchCompletionsRequestContent,
@@ -12,14 +13,20 @@
 
 
 @pytest.fixture
-def create_batch_completions_request():
-    return CreateBatchCompletionsRequest(
+def create_batch_completions_engine_request() -> CreateBatchCompletionsEngineRequest:
+    return CreateBatchCompletionsEngineRequest(
+        input_data_path="input_data_path",
+        output_data_path="output_data_path",
         model_config=CreateBatchCompletionsModelConfig(
-            checkpoint_path="checkpoint_path", model="model", num_shards=4, seed=123, labels={}
+            model="model",
+            checkpoint_path="checkpoint_path",
+            labels={},
+            seed=123,
+            num_shards=4,
         ),
         data_parallelism=1,
-        input_data_path="input_data_path",
-        output_data_path="output_data_path",
+        max_runtime_sec=86400,
+        max_gpu_memory_utilization=0.95,
     )
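A sketch of how a test might consume the updated fixture (hypothetical test; the actual test changes are not shown on this page):

def test_engine_request_carries_gpu_memory_utilization(
    create_batch_completions_engine_request,
):
    request = create_batch_completions_engine_request
    # The fixture pins 0.95, matching what the gateway sets for 70B models.
    assert request.max_gpu_memory_utilization == 0.95
    assert request.model_config.num_shards == 4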
