9
9
import math
10
10
import os
11
11
import re
12
- from dataclasses import asdict
12
+ from dataclasses import asdict , dataclass
13
13
from typing import Any , AsyncIterable , Dict , List , Optional , Union
14
14
15
15
from model_engine_server .common .config import hmi_config
21
21
CompletionStreamV1Response ,
22
22
CompletionSyncV1Request ,
23
23
CompletionSyncV1Response ,
24
+ CreateBatchCompletionsEngineRequest ,
24
25
CreateBatchCompletionsRequest ,
25
26
CreateBatchCompletionsResponse ,
26
27
CreateLLMModelEndpointV1Request ,
@@ -2200,6 +2201,27 @@ async def execute(self, user: User, request: ModelDownloadRequest) -> ModelDownl
2200
2201
return ModelDownloadResponse (urls = urls )
2201
2202
2202
2203
2204
@dataclass
class VLLMEngineArgs:
    """Optional vLLM engine overrides inferred from a model name.

    Each attribute mirrors a vLLM engine argument; a value of ``None``
    means "keep the engine's built-in default".
    """

    # Fraction of GPU memory vLLM may claim for weights + KV cache.
    gpu_memory_utilization: Optional[float] = None
2207
+
2208
+
2209
def infer_addition_engine_args_from_model_name(model_name: str) -> VLLMEngineArgs:
    """Derive additional vLLM engine arguments from *model_name*.

    The last run of digits in the name is treated as the parameter count
    in billions (e.g. ``"llama-2-70b"`` -> 70); models at or above 70B
    are given a higher GPU memory utilization target.

    Raises:
        ObjectHasInvalidValueException: if the name contains no digits,
            so no parameter count can be inferred.
    """
    digit_groups = re.findall(r"\d+", model_name)
    if not digit_groups:
        raise ObjectHasInvalidValueException(
            f"Model {model_name} is not supported for batch completions."
        )

    # Trailing number is assumed to be the model size in billions of params.
    billions_of_params = int(digit_groups[-1])
    gpu_memory_utilization = 0.95 if billions_of_params >= 70 else 0.9

    return VLLMEngineArgs(gpu_memory_utilization=gpu_memory_utilization)
2223
+
2224
+
2203
2225
def infer_hardware_from_model_name (model_name : str ) -> CreateDockerImageBatchJobResourceRequests :
2204
2226
if "mixtral-8x7b" in model_name :
2205
2227
cpus = "20"
@@ -2324,14 +2346,25 @@ async def execute(
2324
2346
assert hardware .gpus is not None
2325
2347
if request .model_config .num_shards :
2326
2348
hardware .gpus = max (hardware .gpus , request .model_config .num_shards )
2327
- request .model_config .num_shards = hardware .gpus
2328
2349
2329
- if request .tool_config and request .tool_config .name != "code_evaluator" :
2350
+ engine_request = CreateBatchCompletionsEngineRequest .from_api (request )
2351
+ engine_request .model_config .num_shards = hardware .gpus
2352
+
2353
+ if engine_request .tool_config and engine_request .tool_config .name != "code_evaluator" :
2330
2354
raise ObjectHasInvalidValueException (
2331
2355
"Only code_evaluator tool is supported for batch completions."
2332
2356
)
2333
2357
2334
- batch_bundle = await self .create_batch_job_bundle (user , request , hardware )
2358
+ additional_engine_args = infer_addition_engine_args_from_model_name (
2359
+ engine_request .model_config .model
2360
+ )
2361
+
2362
+ if additional_engine_args .gpu_memory_utilization is not None :
2363
+ engine_request .max_gpu_memory_utilization = (
2364
+ additional_engine_args .gpu_memory_utilization
2365
+ )
2366
+
2367
+ batch_bundle = await self .create_batch_job_bundle (user , engine_request , hardware )
2335
2368
2336
2369
validate_resource_requests (
2337
2370
bundle = batch_bundle ,
@@ -2342,21 +2375,21 @@ async def execute(
2342
2375
gpu_type = hardware .gpu_type ,
2343
2376
)
2344
2377
2345
- if request .max_runtime_sec is None or request .max_runtime_sec < 1 :
2378
+ if engine_request .max_runtime_sec is None or engine_request .max_runtime_sec < 1 :
2346
2379
raise ObjectHasInvalidValueException ("max_runtime_sec must be a positive integer." )
2347
2380
2348
2381
job_id = await self .docker_image_batch_job_gateway .create_docker_image_batch_job (
2349
2382
created_by = user .user_id ,
2350
2383
owner = user .team_id ,
2351
- job_config = request .dict (),
2384
+ job_config = engine_request .dict (),
2352
2385
env = batch_bundle .env ,
2353
2386
command = batch_bundle .command ,
2354
2387
repo = batch_bundle .image_repository ,
2355
2388
tag = batch_bundle .image_tag ,
2356
2389
resource_requests = hardware ,
2357
- labels = request .model_config .labels ,
2390
+ labels = engine_request .model_config .labels ,
2358
2391
mount_location = batch_bundle .mount_location ,
2359
- override_job_max_runtime_s = request .max_runtime_sec ,
2360
- num_workers = request .data_parallelism ,
2392
+ override_job_max_runtime_s = engine_request .max_runtime_sec ,
2393
+ num_workers = engine_request .data_parallelism ,
2361
2394
)
2362
2395
return CreateBatchCompletionsResponse (job_id = job_id )
0 commit comments