Commit 1e17ab4

Update test fixtures
1 parent a59cf19 commit 1e17ab4

3 files changed: +74 −47 lines changed


model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py

Lines changed: 17 additions & 15 deletions
@@ -2331,30 +2331,32 @@ async def create_batch_job_bundle(
         return batch_bundle
 
     async def execute(
-        self, user: User, _request: CreateBatchCompletionsRequest
+        self, user: User, request: CreateBatchCompletionsRequest
     ) -> CreateBatchCompletionsResponse:
-        hardware = infer_hardware_from_model_name(_request.model_config.model)
+        hardware = infer_hardware_from_model_name(request.model_config.model)
         # Reconcile gpus count with num_shards from request
         assert hardware.gpus is not None
-        if _request.model_config.num_shards:
-            hardware.gpus = max(hardware.gpus, _request.model_config.num_shards)
+        if request.model_config.num_shards:
+            hardware.gpus = max(hardware.gpus, request.model_config.num_shards)
 
-        request = CreateBatchCompletionsEngineRequest.from_api(_request)
-        request.model_config.num_shards = hardware.gpus
+        engine_request = CreateBatchCompletionsEngineRequest.from_api(request)
+        engine_request.model_config.num_shards = hardware.gpus
 
-        if request.tool_config and request.tool_config.name != "code_evaluator":
+        if engine_request.tool_config and engine_request.tool_config.name != "code_evaluator":
             raise ObjectHasInvalidValueException(
                 "Only code_evaluator tool is supported for batch completions."
             )
 
         additional_engine_args = infer_addition_engine_args_from_model_name(
-            request.model_config.model
+            engine_request.model_config.model
         )
 
         if additional_engine_args.gpu_memory_utilization is not None:
-            request.max_gpu_memory_utilization = additional_engine_args.gpu_memory_utilization
+            engine_request.max_gpu_memory_utilization = (
+                additional_engine_args.gpu_memory_utilization
+            )
 
-        batch_bundle = await self.create_batch_job_bundle(user, request, hardware)
+        batch_bundle = await self.create_batch_job_bundle(user, engine_request, hardware)
 
         validate_resource_requests(
             bundle=batch_bundle,
@@ -2365,21 +2367,21 @@ async def execute(
             gpu_type=hardware.gpu_type,
         )
 
-        if request.max_runtime_sec is None or request.max_runtime_sec < 1:
+        if engine_request.max_runtime_sec is None or engine_request.max_runtime_sec < 1:
             raise ObjectHasInvalidValueException("max_runtime_sec must be a positive integer.")
 
         job_id = await self.docker_image_batch_job_gateway.create_docker_image_batch_job(
             created_by=user.user_id,
             owner=user.team_id,
-            job_config=request.dict(),
+            job_config=engine_request.dict(),
             env=batch_bundle.env,
             command=batch_bundle.command,
             repo=batch_bundle.image_repository,
             tag=batch_bundle.image_tag,
             resource_requests=hardware,
-            labels=request.model_config.labels,
+            labels=engine_request.model_config.labels,
             mount_location=batch_bundle.mount_location,
-            override_job_max_runtime_s=request.max_runtime_sec,
-            num_workers=request.data_parallelism,
+            override_job_max_runtime_s=engine_request.max_runtime_sec,
+            num_workers=engine_request.data_parallelism,
         )
         return CreateBatchCompletionsResponse(job_id=job_id)
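
Note on the change above: the incoming API object now keeps the name request, and the value produced by CreateBatchCompletionsEngineRequest.from_api(request) is bound to engine_request, so every later mutation (shard count, GPU memory utilization, runtime limit, job config) unambiguously targets the engine-level payload. A minimal sketch of that API-request-to-engine-request conversion pattern, using illustrative stand-in models rather than the real DTOs from model_engine_server.common.dtos.llms:

# Minimal sketch of the conversion pattern; field names and defaults here are
# illustrative assumptions, not the real DTO definitions.
from typing import Optional

from pydantic import BaseModel


class ApiBatchRequest(BaseModel):  # stand-in for CreateBatchCompletionsRequest
    model: str
    num_shards: Optional[int] = None


class EngineBatchRequest(BaseModel):  # stand-in for CreateBatchCompletionsEngineRequest
    model: str
    num_shards: Optional[int] = None
    max_gpu_memory_utilization: Optional[float] = None

    @classmethod
    def from_api(cls, request: ApiBatchRequest) -> "EngineBatchRequest":
        # Copy the API-facing fields; engine-only fields are filled in afterwards.
        return cls(**request.dict())


api_request = ApiBatchRequest(model="model", num_shards=2)
engine_request = EngineBatchRequest.from_api(api_request)
engine_request.num_shards = 4  # reconciled against inferred hardware, as in execute()

The real from_api may copy more fields or apply defaults; the point is only that engine-side settings are added after the conversion, never on the API object.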

model-engine/tests/unit/inference/conftest.py

Lines changed: 12 additions & 5 deletions
@@ -3,6 +3,7 @@
 import pytest
 from model_engine_server.common.dtos.llms import (
     CompletionOutput,
+    CreateBatchCompletionsEngineRequest,
     CreateBatchCompletionsModelConfig,
     CreateBatchCompletionsRequest,
     CreateBatchCompletionsRequestContent,
@@ -12,14 +13,20 @@
 
 
 @pytest.fixture
-def create_batch_completions_request():
-    return CreateBatchCompletionsRequest(
+def create_batch_completions_engine_request() -> CreateBatchCompletionsEngineRequest:
+    return CreateBatchCompletionsEngineRequest(
+        input_data_path="input_data_path",
+        output_data_path="output_data_path",
         model_config=CreateBatchCompletionsModelConfig(
-            checkpoint_path="checkpoint_path", model="model", num_shards=4, seed=123, labels={}
+            model="model",
+            checkpoint_path="checkpoint_path",
+            labels={},
+            seed=123,
+            num_shards=4,
         ),
         data_parallelism=1,
-        input_data_path="input_data_path",
-        output_data_path="output_data_path",
+        max_runtime_sec=86400,
+        max_gpu_memory_utilization=0.95,
     )
 
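
A hypothetical consumer of the renamed fixture, shown only to illustrate the pytest wiring; the assertions mirror the values set above and this test is not part of the commit:

# Hypothetical test showing how the renamed fixture is requested by name;
# not part of this commit.
def test_engine_request_fixture_defaults(create_batch_completions_engine_request):
    request = create_batch_completions_engine_request
    assert request.model_config.num_shards == 4
    assert request.max_runtime_sec == 86400
    assert request.max_gpu_memory_utilization == 0.95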

model-engine/tests/unit/inference/test_vllm_batch.py

Lines changed: 45 additions & 27 deletions
@@ -7,7 +7,9 @@
 
 @pytest.mark.asyncio
 @patch("model_engine_server.inference.batch_inference.vllm_batch.get_vllm_engine")
-@patch("model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsRequest")
+@patch(
+    "model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsEngineRequest"
+)
 @patch(
     "model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsRequestContent"
 )
@@ -25,9 +27,9 @@ async def test_batch_inference(
     mock_get_s3_client,
     mock_generate_with_vllm,
     mock_create_batch_completions_request_content,
-    mock_create_batch_completions_request,
+    mock_create_batch_completions_engine_request,
     mock_vllm,
-    create_batch_completions_request,
+    create_batch_completions_engine_request,
     create_batch_completions_request_content,
     mock_s3_client,
     mock_process,
@@ -36,7 +38,9 @@ async def test_batch_inference(
     # Mock the necessary objects and data
     mock_popen.return_value = mock_process
     mock_get_s3_client.return_value = mock_s3_client
-    mock_create_batch_completions_request.parse_file.return_value = create_batch_completions_request
+    mock_create_batch_completions_engine_request.parse_file.return_value = (
+        create_batch_completions_engine_request
+    )
     mock_create_batch_completions_request_content.parse_raw.return_value = (
         create_batch_completions_request_content
     )
@@ -48,7 +52,7 @@ async def test_batch_inference(
     await batch_inference()
 
     # Assertions
-    mock_create_batch_completions_request.parse_file.assert_called_once()
+    mock_create_batch_completions_engine_request.parse_file.assert_called_once()
     mock_open_func.assert_has_calls(
         [
             call("input_data_path", "r"),
@@ -61,7 +65,9 @@ async def test_batch_inference(
 
 @pytest.mark.asyncio
 @patch("model_engine_server.inference.batch_inference.vllm_batch.get_vllm_engine")
-@patch("model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsRequest")
+@patch(
+    "model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsEngineRequest"
+)
 @patch(
     "model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsRequestContent"
 )
@@ -79,9 +85,9 @@ async def test_batch_inference_failed_to_download_model_but_proceed(
     mock_get_s3_client,
     mock_generate_with_vllm,
     mock_create_batch_completions_request_content,
-    mock_create_batch_completions_request,
+    mock_create_batch_completions_engine_request,
     mock_vllm,
-    create_batch_completions_request,
+    create_batch_completions_engine_request,
     create_batch_completions_request_content,
     mock_s3_client,
     mock_process,
@@ -91,7 +97,9 @@ async def test_batch_inference_failed_to_download_model_but_proceed(
     mock_process.returncode = 1  # Failed to download model
     mock_popen.return_value = mock_process
     mock_get_s3_client.return_value = mock_s3_client
-    mock_create_batch_completions_request.parse_file.return_value = create_batch_completions_request
+    mock_create_batch_completions_engine_request.parse_file.return_value = (
+        create_batch_completions_engine_request
+    )
     mock_create_batch_completions_request_content.parse_raw.return_value = (
         create_batch_completions_request_content
     )
@@ -103,7 +111,7 @@ async def test_batch_inference_failed_to_download_model_but_proceed(
     await batch_inference()
 
     # Assertions
-    mock_create_batch_completions_request.parse_file.assert_called_once()
+    mock_create_batch_completions_engine_request.parse_file.assert_called_once()
     mock_open_func.assert_has_calls(
         [
             call("input_data_path", "r"),
@@ -116,7 +124,9 @@ async def test_batch_inference_failed_to_download_model_but_proceed(
 
 @pytest.mark.asyncio
 @patch("model_engine_server.inference.batch_inference.vllm_batch.get_vllm_engine")
-@patch("model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsRequest")
+@patch(
+    "model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsEngineRequest"
+)
 @patch(
     "model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsRequestContent"
 )
@@ -136,9 +146,9 @@ async def test_batch_inference_two_workers(
     mock_get_s3_client,
     mock_generate_with_vllm,
     mock_create_batch_completions_request_content,
-    mock_create_batch_completions_request,
+    mock_create_batch_completions_engine_request,
     mock_vllm,
-    create_batch_completions_request,
+    create_batch_completions_engine_request,
     create_batch_completions_request_content,
     mock_s3_client,
     mock_process,
@@ -147,8 +157,10 @@ async def test_batch_inference_two_workers(
     # Mock the necessary objects and data
     mock_popen.return_value = mock_process
     mock_get_s3_client.return_value = mock_s3_client
-    create_batch_completions_request.data_parallelism = 2
-    mock_create_batch_completions_request.parse_file.return_value = create_batch_completions_request
+    create_batch_completions_engine_request.data_parallelism = 2
+    mock_create_batch_completions_engine_request.parse_file.return_value = (
+        create_batch_completions_engine_request
+    )
     mock_create_batch_completions_request_content.parse_raw.return_value = (
         create_batch_completions_request_content
     )
@@ -168,7 +180,7 @@ def side_effect(key, default):
     await batch_inference()
 
     # Assertions
-    mock_create_batch_completions_request.parse_file.assert_called_once()
+    mock_create_batch_completions_engine_request.parse_file.assert_called_once()
     mock_open_func.assert_has_calls(
         [
             call("input_data_path", "r"),
@@ -198,7 +210,9 @@ def side_effect(key, default):
 
 @pytest.mark.asyncio
 @patch("model_engine_server.inference.batch_inference.vllm_batch.get_vllm_engine")
-@patch("model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsRequest")
+@patch(
+    "model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsEngineRequest"
+)
 @patch(
     "model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsRequestContent"
 )
@@ -218,9 +232,9 @@ async def test_batch_inference_delete_chunks(
     mock_get_s3_client,
     mock_generate_with_vllm,
     mock_create_batch_completions_request_content,
-    mock_create_batch_completions_request,
+    mock_create_batch_completions_engine_request,
     mock_vllm,
-    create_batch_completions_request,
+    create_batch_completions_engine_request,
     create_batch_completions_request_content,
     mock_s3_client,
     mock_process,
@@ -229,9 +243,11 @@ async def test_batch_inference_delete_chunks(
     # Mock the necessary objects and data
    mock_popen.return_value = mock_process
     mock_get_s3_client.return_value = mock_s3_client
-    create_batch_completions_request.data_parallelism = 2
-    create_batch_completions_request.output_data_path = "s3://bucket/key"
-    mock_create_batch_completions_request.parse_file.return_value = create_batch_completions_request
+    create_batch_completions_engine_request.data_parallelism = 2
+    create_batch_completions_engine_request.output_data_path = "s3://bucket/key"
+    mock_create_batch_completions_engine_request.parse_file.return_value = (
+        create_batch_completions_engine_request
+    )
     mock_create_batch_completions_request_content.parse_raw.return_value = (
         create_batch_completions_request_content
     )
@@ -251,7 +267,7 @@ def side_effect(key, default):
     await batch_inference()
 
     # Assertions
-    mock_create_batch_completions_request.parse_file.assert_called_once()
+    mock_create_batch_completions_engine_request.parse_file.assert_called_once()
     mock_open_func.assert_has_calls(
         [
             call("input_data_path", "r"),
@@ -310,7 +326,9 @@ def test_file_exists_no_such_key():
 
 @pytest.mark.asyncio
 @patch("model_engine_server.inference.batch_inference.vllm_batch.get_vllm_engine")
-@patch("model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsRequest")
+@patch(
+    "model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsEngineRequest"
+)
 @patch(
     "model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsRequestContent"
 )
@@ -330,7 +348,7 @@ async def test_batch_inference_tool_completion(
     mock_get_s3_client,
     mock_generate_with_vllm,
     mock_create_batch_completions_request_content,
-    mock_create_batch_completions_request,
+    mock_create_batch_completions_engine_request,
     mock_vllm,
     create_batch_completions_tool_completion_request,
     create_batch_completions_tool_completion_request_content,
@@ -344,7 +362,7 @@ async def test_batch_inference_tool_completion(
     mock_run.return_value = mock_run_output
     mock_popen.return_value = mock_process
     mock_get_s3_client.return_value = mock_s3_client
-    mock_create_batch_completions_request.parse_file.return_value = (
+    mock_create_batch_completions_engine_request.parse_file.return_value = (
         create_batch_completions_tool_completion_request
     )
     mock_create_batch_completions_request_content.parse_raw.return_value = (
@@ -361,7 +379,7 @@ async def test_batch_inference_tool_completion(
     await batch_inference()
 
     # Assertions
-    mock_create_batch_completions_request.parse_file.assert_called_once()
+    mock_create_batch_completions_engine_request.parse_file.assert_called_once()
     mock_open_func.assert_has_calls(
         [
             call("input_data_path", "r"),
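
Every updated test follows the same substitution: patch CreateBatchCompletionsEngineRequest where vllm_batch looks it up and route parse_file to the engine-level fixture. A condensed, hypothetical example of that wiring (the real tests also patch the vLLM engine, S3 client, and subprocess collaborators):

# Condensed sketch of the patching pattern used throughout test_vllm_batch.py;
# the test name is hypothetical and the other mocked collaborators are omitted.
from unittest.mock import patch


@patch(
    "model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsEngineRequest"
)
def test_engine_request_is_patched(
    mock_create_batch_completions_engine_request,
    create_batch_completions_engine_request,
):
    # The patched class hands back the fixture instead of parsing a real config file.
    mock_create_batch_completions_engine_request.parse_file.return_value = (
        create_batch_completions_engine_request
    )

    config = mock_create_batch_completions_engine_request.parse_file("input_data_path")
    assert config.data_parallelism == 1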
