Commit a59cf19

Separate Gateway DTO and engine DTO
1 parent 6764c4b commit a59cf19

3 files changed: 35 additions, 11 deletions

model-engine/model_engine_server/common/dtos/llms.py

Lines changed: 24 additions & 4 deletions
@@ -528,17 +528,37 @@ class CreateBatchCompletionsRequest(BaseModel):
     """
     Maximum runtime of the batch inference in seconds. Default to one day.
     """
-    max_gpu_memory_utilization: Optional[float] = Field(default=0.9, le=1.0)
-    """
-    Maximum GPU memory utilization for the batch inference. Default to 90%.
-    """
     tool_config: Optional[ToolConfig] = None
     """
     Configuration for tool use.
     NOTE: this config is highly experimental and signature will change significantly in future iterations.
     """
 
 
+class CreateBatchCompletionsEngineRequest(CreateBatchCompletionsRequest):
+    """
+    Internal model for representing request to the llm engine. This contains additional fields that we want
+    hidden from the DTO exposed to the client.
+    """
+
+    max_gpu_memory_utilization: Optional[float] = Field(default=0.9, le=1.0)
+    """
+    Maximum GPU memory utilization for the batch inference. Default to 90%.
+    """
+
+    @staticmethod
+    def from_api(request: CreateBatchCompletionsRequest) -> "CreateBatchCompletionsEngineRequest":
+        return CreateBatchCompletionsEngineRequest(
+            input_data_path=request.input_data_path,
+            output_data_path=request.output_data_path,
+            content=request.content,
+            model_config=request.model_config,
+            data_parallelism=request.data_parallelism,
+            max_runtime_sec=request.max_runtime_sec,
+            tool_config=request.tool_config,
+        )
+
+
 class CreateBatchCompletionsResponse(BaseModel):
     job_id: str
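For orientation: clients keep sending CreateBatchCompletionsRequest through the gateway, the engine-only knob max_gpu_memory_utilization moves onto the internal subclass, and from_api copies the public fields across at the boundary. A minimal self-contained sketch of the same pattern, using made-up model and field names (PublicBatchRequest, EngineBatchRequest, and the two path fields) rather than the repository's actual classes:

    from typing import Optional

    from pydantic import BaseModel, Field


    class PublicBatchRequest(BaseModel):
        """Gateway-facing DTO: only the fields a client is allowed to set."""

        input_data_path: str
        output_data_path: str


    class EngineBatchRequest(PublicBatchRequest):
        """Internal DTO: inherits the public fields and adds engine-only knobs."""

        max_gpu_memory_utilization: Optional[float] = Field(default=0.9, le=1.0)

        @staticmethod
        def from_api(request: PublicBatchRequest) -> "EngineBatchRequest":
            # Copy the public fields explicitly, mirroring from_api in the diff above.
            return EngineBatchRequest(
                input_data_path=request.input_data_path,
                output_data_path=request.output_data_path,
            )


    api_request = PublicBatchRequest(input_data_path="in.jsonl", output_data_path="out.jsonl")
    engine_request = EngineBatchRequest.from_api(api_request)
    engine_request.max_gpu_memory_utilization = 0.95  # tunable only on the internal model
    print(engine_request.json())

Because the engine model inherits from the public one, the serialized engine request stays a strict superset of what the client sent, so anything that reads the public fields keeps working.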

model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py

Lines changed: 8 additions & 4 deletions
@@ -21,6 +21,7 @@
     CompletionStreamV1Response,
     CompletionSyncV1Request,
     CompletionSyncV1Response,
+    CreateBatchCompletionsEngineRequest,
     CreateBatchCompletionsRequest,
     CreateBatchCompletionsResponse,
     CreateLLMModelEndpointV1Request,
@@ -2330,13 +2331,15 @@ async def create_batch_job_bundle(
         return batch_bundle
 
     async def execute(
-        self, user: User, request: CreateBatchCompletionsRequest
+        self, user: User, _request: CreateBatchCompletionsRequest
     ) -> CreateBatchCompletionsResponse:
-        hardware = infer_hardware_from_model_name(request.model_config.model)
+        hardware = infer_hardware_from_model_name(_request.model_config.model)
         # Reconcile gpus count with num_shards from request
         assert hardware.gpus is not None
-        if request.model_config.num_shards:
-            hardware.gpus = max(hardware.gpus, request.model_config.num_shards)
+        if _request.model_config.num_shards:
+            hardware.gpus = max(hardware.gpus, _request.model_config.num_shards)
+
+        request = CreateBatchCompletionsEngineRequest.from_api(_request)
         request.model_config.num_shards = hardware.gpus
 
         if request.tool_config and request.tool_config.name != "code_evaluator":
@@ -2347,6 +2350,7 @@ async def execute(
         additional_engine_args = infer_addition_engine_args_from_model_name(
             request.model_config.model
         )
+
         if additional_engine_args.gpu_memory_utilization is not None:
             request.max_gpu_memory_utilization = additional_engine_args.gpu_memory_utilization
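In the rewritten execute(), the client-supplied object is only read (hardware inference, shard reconciliation), and everything engine-specific is set on the converted object afterwards. A rough sketch of that ordering, reusing the hypothetical PublicBatchRequest/EngineBatchRequest stand-ins from the earlier sketch; infer_gpu_count is likewise a made-up placeholder, not the repository's infer_hardware_from_model_name:

    def infer_gpu_count(model_name: str) -> int:
        # Placeholder for hardware inference; pretend every model fits on one GPU.
        return 1


    def to_engine_request(api_request: PublicBatchRequest, model_name: str) -> EngineBatchRequest:
        gpus = infer_gpu_count(model_name)  # read-only use of the public DTO
        engine_request = EngineBatchRequest.from_api(api_request)  # convert at the boundary
        # Engine-only tuning happens on the internal model, invisible to API clients.
        engine_request.max_gpu_memory_utilization = 0.9 if gpus == 1 else 0.95
        return engine_request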
model-engine/model_engine_server/inference/batch_inference/vllm_batch.py

Lines changed: 3 additions & 3 deletions
@@ -15,7 +15,7 @@
 from func_timeout import FunctionTimedOut, func_set_timeout
 from model_engine_server.common.dtos.llms import (
     CompletionOutput,
-    CreateBatchCompletionsRequest,
+    CreateBatchCompletionsEngineRequest,
     CreateBatchCompletionsRequestContent,
     TokenOutput,
     ToolConfig,
@@ -145,7 +145,7 @@ def random_uuid() -> str:
     return str(uuid.uuid4().hex)
 
 
-def get_vllm_engine(model, request: CreateBatchCompletionsRequest):
+def get_vllm_engine(model: str, request: CreateBatchCompletionsEngineRequest):
     from vllm import AsyncEngineArgs, AsyncLLMEngine
 
     engine_args = AsyncEngineArgs(
@@ -313,7 +313,7 @@ def tool_func(text: str, past_context: Optional[str]):
 async def batch_inference():
     job_index = int(os.getenv("JOB_COMPLETION_INDEX", 0))
 
-    request = CreateBatchCompletionsRequest.parse_file(CONFIG_FILE)
+    request = CreateBatchCompletionsEngineRequest.parse_file(CONFIG_FILE)
 
     if request.model_config.checkpoint_path is not None:
         download_model(request.model_config.checkpoint_path, MODEL_WEIGHTS_FOLDER)
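Since the batch worker now parses the engine DTO, whatever gets written to CONFIG_FILE has to be the serialized engine request rather than the public one. A small round-trip sketch with the same hypothetical EngineBatchRequest stand-in and an illustrative file name; parse_file is the pydantic v1-style call the worker already uses:

    from pathlib import Path

    # Producer side (use case): persist the engine DTO for the worker to pick up.
    config_file = Path("batch_config.json")
    engine_request = EngineBatchRequest(input_data_path="in.jsonl", output_data_path="out.jsonl")
    config_file.write_text(engine_request.json())

    # Consumer side (batch worker): parse and validate against the engine schema.
    loaded = EngineBatchRequest.parse_file(config_file)
    assert loaded.max_gpu_memory_utilization == engine_request.max_gpu_memory_utilization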
